diff --git a/.travis.yml b/.travis.yml index fe4be6b6..d1ae0dd5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,9 @@ language: go go: - - "1.9.7" - - "1.10.3" + - "1.9.x" + - "1.10.x" + - "1.11.x" services: - redis-server - mysql @@ -34,6 +35,7 @@ install: - go get github.com/gogo/protobuf/proto - go get github.com/Knetic/govaluate - go get github.com/casbin/casbin + - go get github.com/elazarl/go-bindata-assetfs - go get -u honnef.co/go/tools/cmd/gosimple - go get -u github.com/mdempsky/unconvert - go get -u github.com/gordonklaus/ineffassign diff --git a/README.md b/README.md index aa5d8e19..5063645c 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ beego is used for rapid development of RESTful APIs, web apps and backend services in Go. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. + Response time ranking: [web-frameworks](https://github.com/the-benchmarker/web-frameworks). + ###### More info at [beego.me](http://beego.me). ## Quick Start diff --git a/admin.go b/admin.go index 73d4f9f2..715f6b3f 100644 --- a/admin.go +++ b/admin.go @@ -35,7 +35,7 @@ import ( var beeAdminApp *adminApp // FilterMonitorFunc is default monitor filter when admin module is enable. -// if this func returns, admin module records qbs for this request by condition of this function logic. +// if this func returns, admin module records qps for this request by condition of this function logic. // usage: // func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { // if method == "POST" { @@ -67,18 +67,18 @@ func init() { // AdminIndex is the default http.Handler for admin module. // it matches url pattern "/". -func adminIndex(rw http.ResponseWriter, r *http.Request) { +func adminIndex(rw http.ResponseWriter, _ *http.Request) { execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) } -// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter. -// it's registered with url pattern "/qbs" in admin module. -func qpsIndex(rw http.ResponseWriter, r *http.Request) { +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. 
+func qpsIndex(rw http.ResponseWriter, _ *http.Request) { data := make(map[interface{}]interface{}) data["Content"] = toolbox.StatisticsMap.GetMap() // do html escape before display path, avoid xss - if content, ok := (data["Content"]).(map[string]interface{}); ok { + if content, ok := (data["Content"]).(M); ok { if resultLists, ok := (content["Data"]).([][]string); ok { for i := range resultLists { if len(resultLists[i]) > 0 { @@ -104,7 +104,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { data := make(map[interface{}]interface{}) switch command { case "conf": - m := make(map[string]interface{}) + m := make(M) list("BConfig", BConfig, m) m["AppConfigPath"] = appConfigPath m["AppConfigProvider"] = appConfigProvider @@ -128,14 +128,14 @@ func listConf(rw http.ResponseWriter, r *http.Request) { execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) case "filter": var ( - content = map[string]interface{}{ + content = M{ "Fields": []string{ "Router Pattern", "Filter Function", }, } filterTypes = []string{} - filterTypeData = make(map[string]interface{}) + filterTypeData = make(M) ) if BeeApp.Handlers.enableFilter { @@ -173,7 +173,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { } } -func list(root string, p interface{}, m map[string]interface{}) { +func list(root string, p interface{}, m M) { pt := reflect.TypeOf(p) pv := reflect.ValueOf(p) if pt.Kind() == reflect.Ptr { @@ -196,11 +196,11 @@ func list(root string, p interface{}, m map[string]interface{}) { } // PrintTree prints all registered routers. -func PrintTree() map[string]interface{} { +func PrintTree() M { var ( - content = map[string]interface{}{} + content = M{} methods = []string{} - methodsData = make(map[string]interface{}) + methodsData = make(M) ) for method, t := range BeeApp.Handlers.routers { @@ -291,12 +291,12 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { // Healthcheck is a http.Handler calling health checking and showing the result. // it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, req *http.Request) { +func healthcheck(rw http.ResponseWriter, _ *http.Request) { var ( result []string data = make(map[interface{}]interface{}) resultList = new([][]string) - content = map[string]interface{}{ + content = M{ "Fields": []string{"Name", "Message", "Status"}, } ) @@ -344,7 +344,7 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { } // List Tasks - content := make(map[string]interface{}) + content := make(M) resultList := new([][]string) var fields = []string{ "Task Name", diff --git a/admin_test.go b/admin_test.go index 03a3c044..539837cf 100644 --- a/admin_test.go +++ b/admin_test.go @@ -6,7 +6,7 @@ import ( ) func TestList_01(t *testing.T) { - m := make(map[string]interface{}) + m := make(M) list("BConfig", BConfig, m) t.Log(m) om := oldMap() @@ -18,8 +18,8 @@ func TestList_01(t *testing.T) { } } -func oldMap() map[string]interface{} { - m := make(map[string]interface{}) +func oldMap() M { + m := make(M) m["BConfig.AppName"] = BConfig.AppName m["BConfig.RunMode"] = BConfig.RunMode m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive diff --git a/beego.go b/beego.go index 6b90c61c..ac2b4f46 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. 
- VERSION = "1.10.1" + VERSION = "1.11.0" // DEV is for develop DEV = "dev" @@ -31,7 +31,10 @@ const ( PROD = "prod" ) -//hook function to run +// M is Map shortcut +type M map[string]interface{} + +// Hook function to run type hookfunc func() error var ( diff --git a/cache/memory.go b/cache/memory.go index 57e868cf..cb9802ab 100644 --- a/cache/memory.go +++ b/cache/memory.go @@ -203,13 +203,17 @@ func (bc *MemoryCache) StartAndGC(config string) error { dur := time.Duration(cf["interval"]) * time.Second bc.Every = cf["interval"] bc.dur = dur - go bc.vaccuum() + go bc.vacuum() return nil } // check expiration. -func (bc *MemoryCache) vaccuum() { - if bc.Every < 1 { +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { return } for { diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 55fea25e..372dd48b 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -39,6 +39,7 @@ import ( "github.com/gomodule/redigo/redis" "github.com/astaxie/beego/cache" + "strings" ) var ( @@ -164,6 +165,14 @@ func (rc *Cache) StartAndGC(config string) error { if _, ok := cf["conn"]; !ok { return errors.New("config has no conn key") } + + // Format redis://@: + cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) + if i := strings.Index(cf["conn"], "@"); i > -1 { + cf["password"] = cf["conn"][0:i] + cf["conn"] = cf["conn"][i+1:] + } + if _, ok := cf["dbNum"]; !ok { cf["dbNum"] = "0" } diff --git a/config/ini.go b/config/ini.go index 1376d933..002e5e05 100644 --- a/config/ini.go +++ b/config/ini.go @@ -78,15 +78,37 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e } } section := defaultSection + tmpBuf := bytes.NewBuffer(nil) for { - line, _, err := buf.ReadLine() - if err == io.EOF { + tmpBuf.Reset() + + shouldBreak := false + for { + tmp, isPrefix, err := buf.ReadLine() + if err == io.EOF { + shouldBreak = true + break + } + + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } + + tmpBuf.Write(tmp) + if isPrefix { + continue + } + + if !isPrefix { + break + } + } + if shouldBreak { break } - //It might be a good idea to throw a error on all unknonw errors? 
- if _, ok := err.(*os.PathError); ok { - return nil, err - } + + line := tmpBuf.Bytes() line = bytes.TrimSpace(line) if bytes.Equal(line, bEmpty) { continue diff --git a/config_test.go b/config_test.go index 379fe1c6..53411b01 100644 --- a/config_test.go +++ b/config_test.go @@ -48,15 +48,15 @@ func TestAssignConfig_02(t *testing.T) { _BConfig := &Config{} bs, _ := json.Marshal(newBConfig()) - jsonMap := map[string]interface{}{} + jsonMap := M{} json.Unmarshal(bs, &jsonMap) - configMap := map[string]interface{}{} + configMap := M{} for k, v := range jsonMap { if reflect.TypeOf(v).Kind() == reflect.Map { - for k1, v1 := range v.(map[string]interface{}) { + for k1, v1 := range v.(M) { if reflect.TypeOf(v1).Kind() == reflect.Map { - for k2, v2 := range v1.(map[string]interface{}) { + for k2, v2 := range v1.(M) { configMap[k2] = v2 } } else { diff --git a/context/context.go b/context/context.go index 8b32062c..452834e5 100644 --- a/context/context.go +++ b/context/context.go @@ -38,6 +38,14 @@ import ( "github.com/astaxie/beego/utils" ) +//commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" +) + // NewContext return the Context with Input and Output func NewContext() *Context { return &Context{ @@ -244,3 +252,11 @@ func (r *Response) CloseNotify() <-chan bool { } return nil } + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} \ No newline at end of file diff --git a/context/output.go b/context/output.go index 3cc6044e..3e277ab2 100644 --- a/context/output.go +++ b/context/output.go @@ -206,7 +206,7 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) // YAML writes yaml to response body. func (output *BeegoOutput) YAML(data interface{}) error { - output.Header("Content-Type", "application/application/x-yaml; charset=utf-8") + output.Header("Content-Type", "application/x-yaml; charset=utf-8") var content []byte var err error content, err = yaml.Marshal(data) @@ -260,6 +260,19 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { return output.Body(content) } +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + accept := output.Context.Input.Header("Accept") + switch accept { + case ApplicationYAML: + output.YAML(data) + case ApplicationXML, TextXML: + output.XML(data, hasIndent) + default: + output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + } +} + // Download forces response for download file. // it prepares the download response header automatically. func (output *BeegoOutput) Download(file string, filename ...string) { diff --git a/controller.go b/controller.go index 8be43a33..4b8f9807 100644 --- a/controller.go +++ b/controller.go @@ -32,14 +32,6 @@ import ( "github.com/astaxie/beego/session" ) -//commonly used mime-types -const ( - applicationJSON = "application/json" - applicationXML = "application/xml" - applicationYAML = "application/x-yaml" - textXML = "text/xml" -) - var ( // ErrAbort custom error when user stop request handler manually. 
ErrAbort = errors.New("User stop run") @@ -47,10 +39,37 @@ var ( GlobalControllerRouter = make(map[string][]ControllerComments) ) +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + // ControllerComments store the comment for the controller method type ControllerComments struct { Method string Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments AllowHTTPMethods []string Params []map[string]string MethodParams []*param.MethodParam @@ -275,16 +294,15 @@ func (c *Controller) viewPath() string { func (c *Controller) Redirect(url string, code int) { logAccess(c.Ctx, nil, code) c.Ctx.Redirect(code, url) - panic(ErrAbort) } -// Set the data depending on the accepted +// SetData set the data depending on the accepted func (c *Controller) SetData(data interface{}) { accept := c.Ctx.Input.Header("Accept") switch accept { - case applicationJSON: - c.Data["json"] = data - case applicationXML, textXML: + case context.ApplicationYAML: + c.Data["yaml"] = data + case context.ApplicationXML, context.TextXML: c.Data["xml"] = data default: c.Data["json"] = data @@ -333,54 +351,35 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string { // ServeJSON sends a json response with encoding charset. func (c *Controller) ServeJSON(encoding ...bool) { var ( - hasIndent = true - hasEncoding = false + hasIndent = BConfig.RunMode != PROD + hasEncoding = len(encoding) > 0 && encoding[0] ) - if BConfig.RunMode == PROD { - hasIndent = false - } - if len(encoding) > 0 && encoding[0] { - hasEncoding = true - } + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) } // ServeJSONP sends a jsonp response. func (c *Controller) ServeJSONP() { - hasIndent := true - if BConfig.RunMode == PROD { - hasIndent = false - } + hasIndent := BConfig.RunMode != PROD c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) } // ServeXML sends xml response. func (c *Controller) ServeXML() { - hasIndent := true - if BConfig.RunMode == PROD { - hasIndent = false - } + hasIndent := BConfig.RunMode != PROD c.Ctx.Output.XML(c.Data["xml"], hasIndent) } -// ServeXML sends xml response. +// ServeYAML sends yaml response. func (c *Controller) ServeYAML() { c.Ctx.Output.YAML(c.Data["yaml"]) } -// ServeFormatted serve Xml OR Json, depending on the value of the Accept header -func (c *Controller) ServeFormatted() { - accept := c.Ctx.Input.Header("Accept") - switch accept { - case applicationJSON: - c.ServeJSON() - case applicationXML, textXML: - c.ServeXML() - case applicationYAML: - c.ServeYAML() - default: - c.ServeJSON() - } +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + hasIndent := BConfig.RunMode != PROD + hasEncoding := len(encoding) > 0 && encoding[0] + c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) } // Input returns the input data map from POST or PUT request body and query string. 
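Note on the negotiation added above: the mime constants in context, BeegoOutput.ServeFormatted, and Controller.ServeFormatted match the Accept header value exactly, with no q-value parsing. A minimal sketch of a controller exercising it; PingController and its payload are hypothetical, while beego.M and Output.ServeFormatted come from this change:

package controllers

import "github.com/astaxie/beego"

// PingController is a hypothetical controller illustrating the new
// Accept-driven serialization.
type PingController struct {
	beego.Controller
}

// Get writes one payload and lets Output.ServeFormatted pick the encoder:
// "application/x-yaml" -> YAML, "application/xml" or "text/xml" -> XML,
// anything else -> JSON. The second argument toggles indentation;
// Controller.ServeFormatted derives it from BConfig.RunMode != PROD.
func (c *PingController) Get() {
	c.Ctx.Output.ServeFormatted(beego.M{"status": "ok"}, true)
}

A request carrying "Accept: application/x-yaml" then receives the YAML encoding of the same map that a plain request receives as JSON.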
diff --git a/error.go b/error.go index 7a50a5b3..727830df 100644 --- a/error.go +++ b/error.go @@ -361,7 +361,7 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ + data := M{ "Title": http.StatusText(errCode), "BeegoVersion": VERSION, "Content": template.HTML(errContent), diff --git a/fs.go b/fs.go new file mode 100644 index 00000000..bf7002ad --- /dev/null +++ b/fs.go @@ -0,0 +1,74 @@ +package beego + +import ( + "net/http" + "os" + "path/filepath" +) + +type FileSystem struct { +} + +func (d FileSystem) Open(name string) (http.File, error) { + return os.Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + + f, err := fs.Open(root) + if err != nil { + return err + } + info, err := f.Stat() + if err != nil { + err = walkFn(root, nil, err) + } else { + err = walk(fs, root, info, walkFn) + } + if err == filepath.SkipDir { + return nil + } + return err } + +// walk recursively descends path, calling walkFn. +func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { + var err error + if !info.IsDir() { + return walkFn(path, info, nil) + } + + dir, err := fs.Open(path) + if err != nil { + if err1 := walkFn(path, info, err); err1 != nil { + return err1 + } + return err + } + defer dir.Close() + dirs, err := dir.Readdir(-1) + err1 := walkFn(path, info, err) + // If err != nil, walk can't walk into this directory. + // err1 != nil means walkFn wants walk to skip this directory or stop walking. + // Therefore, if one of err and err1 isn't nil, walk will return. + if err != nil || err1 != nil { + // The caller's behavior is controlled by the return value, which is decided + // by walkFn. walkFn may ignore err and return nil. + // If walkFn returns SkipDir, it will be handled by the caller. + // So walk should return whatever walkFn returns. 
+ return err1 + } + + for _, fileInfo := range dirs { + filename := filepath.Join(path, fileInfo.Name()) + if err = walk(fs, filename, fileInfo, walkFn); err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..a3314457 --- /dev/null +++ b/go.mod @@ -0,0 +1,40 @@ +module github.com/astaxie/beego + +require ( + github.com/BurntSushi/toml v0.3.1 // indirect + github.com/Knetic/govaluate v3.0.0+incompatible // indirect + github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd + github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 + github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff + github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 + github.com/casbin/casbin v1.6.0 + github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 + github.com/couchbase/go-couchbase v0.0.0-20181019154153-595f46701cbc + github.com/couchbase/gomemcached v0.0.0-20180723192129-20e69a1ee160 // indirect + github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect + github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect + github.com/elazarl/go-bindata-assetfs v0.0.0-20180223110309-38087fe4dafb + github.com/go-redis/redis v6.14.2+incompatible + github.com/go-sql-driver/mysql v1.4.0 + github.com/gogo/protobuf v1.1.1 + github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/gomodule/redigo v2.0.0+incompatible + github.com/lib/pq v1.0.0 + github.com/mattn/go-sqlite3 v1.10.0 + github.com/onsi/gomega v1.4.2 // indirect + github.com/pelletier/go-toml v1.2.0 // indirect + github.com/pkg/errors v0.8.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect + github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 + github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect + github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec + github.com/stretchr/testify v1.2.2 // indirect + github.com/syndtr/goleveldb v0.0.0-20181105012736-f9080354173f // indirect + github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect + golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb + google.golang.org/appengine v1.1.0 // indirect + gopkg.in/yaml.v2 v2.2.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..547836e9 --- /dev/null +++ b/go.sum @@ -0,0 +1,94 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= +github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M= +github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ= +github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk= +github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU= +github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff h1:/kO0p2RTGLB8R5gub7ps0GmYpB2O8LXEoPq8tzFDCUI= 
+github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff/go.mod h1:PhH1ZhyCzHKt4uAasyx+ljRCgoezetRNf59CUtwUkqY= +github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g= +github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= +github.com/casbin/casbin v1.6.0 h1:uIhuV5I0ilXGUm3y+xJ8nG7VOnYDeZZQiNsFOTF2QmI= +github.com/casbin/casbin v1.6.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE= +github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= +github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= +github.com/couchbase/go-couchbase v0.0.0-20181019154153-595f46701cbc h1:Byzmalcea3rzOdgt4Ny3xrtXkd25zUMPFI5oeKksSbU= +github.com/couchbase/go-couchbase v0.0.0-20181019154153-595f46701cbc/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= +github.com/couchbase/gomemcached v0.0.0-20180723192129-20e69a1ee160 h1:yaqs73s76owCkJbPZo8GKSosZoMjezdLDslJ8aaDk0w= +github.com/couchbase/gomemcached v0.0.0-20180723192129-20e69a1ee160/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c= +github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a h1:Y5XsLCEhtEI8qbD9RP3Qlv5FXdTDHxZM9UPUnMRgBp8= +github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= +github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 h1:Lgdd/Qp96Qj8jqLpq2cI1I1X7BJnu06efS+XkhRoLUQ= +github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8= +github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/elazarl/go-bindata-assetfs v0.0.0-20180223110309-38087fe4dafb h1:T6FhFH6fLQPEu7n7PauDhb4mhpxhlfaL7a7MZEpIgDc= +github.com/elazarl/go-bindata-assetfs v0.0.0-20180223110309-38087fe4dafb/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0= +github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v2.0.0+incompatible 
h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= +github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I= +github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM= +github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= +github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 h1:p6IxqQMjab30l4lb9mmkIkkcE1yv6o0SKbPhW5pxqHI= +github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373/go.mod h1:mF1DpOSOUiJRMR+FDqaqu3EBqrybQtrDDszLUZ6oxPg= +github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d h1:NVwnfyR3rENtlz62bcrkXME3INVUa4lcdGt+opvxExs= +github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA= +github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec h1:q6XVwXmKvCRHRqesF3cSv6lNqqHi0QWOvgDlSohg8UA= +github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/syndtr/goleveldb v0.0.0-20181105012736-f9080354173f h1:EEVjSRihF8NIbfyCcErpSpNHEKrY3s8EAwqiPENZZn8= +github.com/syndtr/goleveldb v0.0.0-20181105012736-f9080354173f/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= +github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8= +github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= +golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb h1:Ah9YqXLj6fEgeKqcmBuLCbAsrF3ScD7dJ/bYM0C6tXI= +golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f 
h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.1.0 h1:igQkv0AAhEIvTEpD5LIpAfav2eeVO9HBTjvKHVJPRSs= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 32d3e7f6..8970b764 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -16,6 +16,8 @@ package httplib import ( "io/ioutil" + "net" + "net/http" "os" "strings" "testing" @@ -161,7 +163,16 @@ func TestWithSetting(t *testing.T) { var setting BeegoHTTPSettings setting.EnableCookie = true setting.UserAgent = v - setting.Transport = nil + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } setting.ReadWriteTimeout = 5 * time.Second SetDefaultSetting(setting) diff --git a/orm/db.go b/orm/db.go index 87d08df6..cb807b5a 100644 --- a/orm/db.go +++ b/orm/db.go @@ -928,7 +928,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi maps[fi.column] = true } } else { - panic(fmt.Errorf("wrong field/column name `%s`", col)) + return 0, fmt.Errorf("wrong field/column name `%s`", col) } } if hasRel { diff --git a/orm/db_tables.go b/orm/db_tables.go index 42be5550..4b21a6fc 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -372,7 +372,13 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe operator = "exact" } - operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) + var operSQL string + var args []interface{} + if p.isRaw { + operSQL = p.sql + } else { + operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) + } leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) diff --git a/orm/orm.go b/orm/orm.go index e59589aa..b00c974e 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -548,6 +548,9 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName + al.DB = db + + detectTZ(al) o := new(orm) o.alias = al diff --git a/orm/orm_conds.go b/orm/orm_conds.go index f6e389ec..f3fd66f0 100644 --- 
a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -31,6 +31,8 @@ type condValue struct { isOr bool isNot bool isCond bool + isRaw bool + sql string } // Condition struct. @@ -45,6 +47,15 @@ func NewCondition() *Condition { return c } +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + if len(sql) == 0 { + panic(fmt.Errorf(" sql cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) + return &c +} + // And add expression to condition func (c Condition) And(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 8b5d6fd2..b2657c6d 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -79,6 +79,15 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { return &o } +// add raw sql to querySeter. +func (o querySet) FilterRaw(expr string, sql string) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.Raw(expr, sql) + return &o +} + // add NOT condition to querySeter. func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { if o.cond == nil { @@ -198,11 +207,7 @@ func (o *querySet) PrepareInsert() (Inserter, error) { // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) - if num == 0 { - return 0, ErrNoRows - } - return num, err + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. diff --git a/orm/orm_test.go b/orm/orm_test.go index cdb2fe3f..89a714b6 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -899,6 +899,18 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("id__between", []int{2, 3}).Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) } func TestSetCond(t *testing.T) { @@ -924,6 +936,11 @@ func TestSetCond(t *testing.T) { num, err = qs.SetCond(cond4).Count() throwFail(t, err) throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) } func TestLimit(t *testing.T) { @@ -1011,13 +1028,13 @@ func TestAll(t *testing.T) { qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, AssertIs(err, ErrNoRows)) + throwFailNow(t, err) throwFailNow(t, AssertIs(num, 0)) var users3 []*User qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, AssertIs(err, ErrNoRows)) + throwFailNow(t, err) throwFailNow(t, AssertIs(num, 0)) throwFailNow(t, AssertIs(users3 == nil, false)) } diff --git a/orm/types.go b/orm/types.go index 2fdc98c7..8ba4099f 100644 --- a/orm/types.go +++ b/orm/types.go @@ -147,6 +147,11 @@ type QuerySeter interface { // // time compare // qs.Filter("created", time.Now()) Filter(string, 
...interface{}) QuerySeter + // add raw sql to querySeter. + // for example: + // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") + // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) + FilterRaw(string, string) QuerySeter // add NOT condition to querySeter. // have the same usage as Filter Exclude(string, ...interface{}) QuerySeter diff --git a/orm/utils.go b/orm/utils.go index 669d4734..b76bb2e3 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -198,7 +198,7 @@ func ToInt64(value interface{}) (d int64) { return } -// snake string, XxYy to xx_yy , XxYY to xx_yy +// snake string, XxYy to xx_yy , XxYY to xx_y_y func snakeString(s string) string { data := make([]byte, 0, len(s)*2) j := false diff --git a/orm/utils_test.go b/orm/utils_test.go index 8c7c5008..11e76687 100644 --- a/orm/utils_test.go +++ b/orm/utils_test.go @@ -34,3 +34,20 @@ func TestCamelString(t *testing.T) { } } } + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/parser.go b/parser.go index 2ac48b85..a8690274 100644 --- a/parser.go +++ b/parser.go @@ -39,7 +39,7 @@ var globalRouterTemplate = `package routers import ( "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/context/param"{{.globalimport}} ) func init() { @@ -52,6 +52,22 @@ var ( commentFilename string pkgLastupdate map[string]int64 genInfoList map[string][]ControllerComments + + routerHooks = map[string]int{ + "beego.BeforeStatic": BeforeStatic, + "beego.BeforeRouter": BeforeRouter, + "beego.BeforeExec": BeforeExec, + "beego.AfterExec": AfterExec, + "beego.FinishRouter": FinishRouter, + } + + routerHooksMapping = map[int]string{ + BeforeStatic: "beego.BeforeStatic", + BeforeRouter: "beego.BeforeRouter", + BeforeExec: "beego.BeforeExec", + AfterExec: "beego.AfterExec", + FinishRouter: "beego.FinishRouter", + } ) const commentPrefix = "commentsRouter_" @@ -102,6 +118,20 @@ type parsedComment struct { routerPath string methods []string params map[string]parsedParam + filters []parsedFilter + imports []parsedImport +} + +type parsedImport struct { + importPath string + importAlias string +} + +type parsedFilter struct { + pattern string + pos int + filter string + params []bool } type parsedParam struct { @@ -126,6 +156,8 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { cc.Router = parsedComment.routerPath cc.AllowHTTPMethods = parsedComment.methods cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + cc.FilterComments = buildFilters(parsedComment.filters) + cc.ImportComments = buildImports(parsedComment.imports) genInfoList[key] = append(genInfoList[key], cc) } } @@ -133,6 +165,48 @@ func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { return nil } +func buildImports(pis []parsedImport) []*ControllerImportComments { + var importComments []*ControllerImportComments + + for _, pi := range pis { + importComments = append(importComments, &ControllerImportComments{ + ImportPath: pi.importPath, + ImportAlias: pi.importAlias, + }) + } + + return importComments +} + +func buildFilters(pfs []parsedFilter) 
[]*ControllerFilterComments { + var filterComments []*ControllerFilterComments + + for _, pf := range pfs { + var ( + returnOnOutput bool + resetParams bool + ) + + if len(pf.params) >= 1 { + returnOnOutput = pf.params[0] + } + + if len(pf.params) >= 2 { + resetParams = pf.params[1] + } + + filterComments = append(filterComments, &ControllerFilterComments{ + Filter: pf.filter, + Pattern: pf.pattern, + Pos: pf.pos, + ReturnOnOutput: returnOnOutput, + ResetParams: resetParams, + }) + } + + return filterComments +} + func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { result := make([]*param.MethodParam, 0, len(funcParams)) for _, fparam := range funcParams { @@ -181,6 +255,8 @@ var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { pcs = []*parsedComment{} params := map[string]parsedParam{} + filters := []parsedFilter{} + imports := []parsedImport{} for _, c := range lines { t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) @@ -209,9 +285,69 @@ func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { } } + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Import") { + iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) + if len(iv) == 0 || len(iv) > 2 { + logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") + continue + } + + p := parsedImport{} + p.importPath = iv[0] + + if len(iv) == 2 { + p.importAlias = iv[1] + } + + imports = append(imports, p) + } + } + +filterLoop: + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Filter") { + fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) + if len(fv) < 3 { + logs.Error("Invalid @Filter format. 
Needs at least 3 parameters") + continue filterLoop + } + + p := parsedFilter{} + p.pattern = fv[0] + posName := fv[1] + if pos, exists := routerHooks[posName]; exists { + p.pos = pos + } else { + logs.Error("Invalid @Filter pos: ", posName) + continue filterLoop + } + + p.filter = fv[2] + fvParams := fv[3:] + for _, fvParam := range fvParams { + switch fvParam { + case "true": + p.params = append(p.params, true) + case "false": + p.params = append(p.params, false) + default: + logs.Error("Invalid @Filter param: ", fvParam) + continue filterLoop + } + } + + filters = append(filters, p) + } + } + for _, c := range lines { var pc = &parsedComment{} pc.params = params + pc.filters = filters + pc.imports = imports t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) if strings.HasPrefix(t, "@router") { @@ -276,8 +412,9 @@ func genRouterCode(pkgRealpath string) { os.Mkdir(getRouterDir(pkgRealpath), 0755) logs.Info("generate router from comments") var ( - globalinfo string - sortKey []string + globalinfo string + globalimport string + sortKey []string ) for k := range genInfoList { sortKey = append(sortKey, k) @@ -295,6 +432,7 @@ func genRouterCode(pkgRealpath string) { } allmethod = strings.TrimRight(allmethod, ",") + "}" } + params := "nil" if len(c.Params) > 0 { params = "[]map[string]string{" @@ -305,6 +443,7 @@ func genRouterCode(pkgRealpath string) { } params = strings.TrimRight(params, ",") + "}" } + methodParams := "param.Make(" if len(c.MethodParams) > 0 { lines := make([]string, 0, len(c.MethodParams)) @@ -316,24 +455,66 @@ func genRouterCode(pkgRealpath string) { ",\n " } methodParams += ")" + + imports := "" + if len(c.ImportComments) > 0 { + for _, i := range c.ImportComments { + if i.ImportAlias != "" { + imports += fmt.Sprintf(` + %s "%s"`, i.ImportAlias, i.ImportPath) + } else { + imports += fmt.Sprintf(` + "%s"`, i.ImportPath) + } + } + } + + filters := "" + if len(c.FilterComments) > 0 { + for _, f := range c.FilterComments { + filters += fmt.Sprintf(` &beego.ControllerFilter{ + Pattern: "%s", + Pos: %s, + Filter: %s, + ReturnOnOutput: %v, + ResetParams: %v, + },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) + } + } + + if filters == "" { + filters = "nil" + } else { + filters = fmt.Sprintf(`[]*beego.ControllerFilter{ +%s + }`, filters) + } + + globalimport = imports + globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + "Router: `" + c.Router + "`" + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Params: ` + params + `}) + beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], + beego.ControllerComments{ + Method: "` + strings.TrimSpace(c.Method) + `", + ` + "Router: `" + c.Router + "`" + `, + AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, + Filters: ` + filters + `, + Params: ` + params + `}) ` } } + if globalinfo != "" { f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) if err != nil { panic(err) } defer f.Close() - f.WriteString(strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)) + + content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) + f.WriteString(content) } } diff --git a/router.go b/router.go index 2e0cea25..4ecac6bc 100644 
--- a/router.go +++ b/router.go @@ -43,7 +43,7 @@ const ( ) const ( - routerTypeBeego = iota + routerTypeBeego = iota routerTypeRESTFul routerTypeHandler ) @@ -277,6 +277,10 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { key := t.PkgPath() + ":" + t.Name() if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + } + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) } } @@ -794,7 +798,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !isRunnable { //Invoke the request handler var execController ControllerInterface - if routerInfo.initialize != nil { + if routerInfo != nil && routerInfo.initialize != nil { execController = routerInfo.initialize() } else { vc := reflect.New(runRouter) @@ -877,7 +881,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } Admin: -//admin module record QPS + //admin module record QPS statusCode := context.ResponseWriter.Status if statusCode == 0 { diff --git a/session/sess_file.go b/session/sess_file.go index 53e19811..c089dade 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -21,6 +21,7 @@ import ( "os" "path" "path/filepath" + "strings" "sync" "time" ) @@ -127,6 +128,9 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { // if file is not exist, create it. // the file path is generated from sid string. func (fp *FileProvider) SessionRead(sid string) (Store, error) { + if strings.ContainsAny(sid, "./") { + return nil, nil + } filepder.lock.Lock() defer filepder.lock.Unlock() diff --git a/session/session.go b/session/session.go index cf647521..c7e7dc69 100644 --- a/session/session.go +++ b/session/session.go @@ -96,6 +96,7 @@ type ManagerConfig struct { EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` } // Manager contains Provider and its configuration. @@ -153,6 +154,11 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { }, nil } +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return manager.provider +} + // getSid retrieves session identifier from HTTP Request. // First try to retrieve id by reading from cookie, session cookie name is configurable, // if not exist, then retrieve id from querying parameters. @@ -331,7 +337,7 @@ func (manager *Manager) sessionID() (string, error) { if n != len(b) || err != nil { return "", fmt.Errorf("Could not successfully read from the system CSPRNG") } - return hex.EncodeToString(b), nil + return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil } // Set cookie with https. 
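A minimal sketch of the two session additions above, SessionIDPrefix and GetProvider, in standalone use; the provider name "memory" and all config values are illustrative:

package main

import (
	"fmt"

	"github.com/astaxie/beego/session"
)

func main() {
	// Every generated session id is now prepended with the configured
	// prefix in Manager.sessionID(), e.g. "myapp-3fa85f64...".
	manager, err := session.NewManager("memory", &session.ManagerConfig{
		CookieName:      "gosessionid",
		Gclifetime:      3600,
		SessionIDPrefix: "myapp-",
	})
	if err != nil {
		panic(err)
	}
	go manager.GC()

	// GetProvider exposes the previously private Provider, e.g. for
	// custom maintenance or introspection.
	fmt.Printf("session provider in use: %T\n", manager.GetProvider())
}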
diff --git a/staticfile.go b/staticfile.go index 30ee3dfb..68241a86 100644 --- a/staticfile.go +++ b/staticfile.go @@ -178,7 +178,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/template.go b/template.go index 88c84e2f..d673a54d 100644 --- a/template.go +++ b/template.go @@ -20,6 +20,7 @@ import ( "html/template" "io" "io/ioutil" + "net/http" "os" "path/filepath" "regexp" @@ -40,6 +41,7 @@ var ( beeTemplateExt = []string{"tpl", "html"} // beeTemplatePreprocessors stores associations of extension -> preprocessor handler beeTemplateEngines = map[string]templatePreProcessor{} + beeTemplateFS = defaultFSFunc ) // ExecuteTemplate applies the template with name to the specified data object, @@ -181,12 +183,17 @@ func lockViewPaths() { // BuildTemplate will build all template files in a directory. // it makes beego can render any template file in view directory. func BuildTemplate(dir string, files ...string) error { - if _, err := os.Stat(dir); err != nil { + var err error + fs := beeTemplateFS() + f, err := fs.Open(dir) + defer f.Close() + if err != nil { if os.IsNotExist(err) { return nil } return errors.New("dir open err") } + beeTemplates, ok := beeViewPathTemplates[dir] if !ok { panic("Unknown view path: " + dir) @@ -195,11 +202,11 @@ func BuildTemplate(dir string, files ...string) error { root: dir, files: make(map[string][]string), } - err := filepath.Walk(dir, func(path string, f os.FileInfo, err error) error { + err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { return self.visit(path, f, err) }) if err != nil { - fmt.Printf("filepath.Walk() returned %v\n", err) + fmt.Printf("Walk() returned %v\n", err) return err } buildAllFiles := len(files) == 0 @@ -210,11 +217,11 @@ func BuildTemplate(dir string, files ...string) error { ext := filepath.Ext(file) var t *template.Template if len(ext) == 0 { - t, err = getTemplate(self.root, file, v...) + t, err = getTemplate(self.root, fs, file, v...) } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { t, err = fn(self.root, file, beegoTplFuncMap) } else { - t, err = getTemplate(self.root, file, v...) + t, err = getTemplate(self.root, fs, file, v...) 
} if err != nil { logs.Error("parse template err:", file, err) @@ -229,9 +236,10 @@ func BuildTemplate(dir string, files ...string) error { return nil } -func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) { +func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { var fileAbsPath string var rParent string + var err error if filepath.HasPrefix(file, "../") { rParent = filepath.Join(filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) @@ -239,10 +247,12 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp rParent = file fileAbsPath = filepath.Join(root, file) } - if e := utils.FileExists(fileAbsPath); !e { + f, err := fs.Open(fileAbsPath) + defer f.Close() + if err != nil { panic("can't find template file:" + file) } - data, err := ioutil.ReadFile(fileAbsPath) + data, err := ioutil.ReadAll(f) if err != nil { return nil, [][]string{}, err } @@ -261,7 +271,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp if !HasTemplateExt(m[1]) { continue } - _, _, err = getTplDeep(root, m[1], rParent, t) + _, _, err = getTplDeep(root, fs, m[1], rParent, t) if err != nil { return nil, [][]string{}, err } @@ -270,14 +280,14 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp return t, allSub, nil } -func getTemplate(root, file string, others ...string) (t *template.Template, err error) { +func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) var subMods [][]string - t, subMods, err = getTplDeep(root, file, "", t) + t, subMods, err = getTplDeep(root, fs, file, "", t) if err != nil { return nil, err } - t, err = _getTemplate(t, root, subMods, others...) + t, err = _getTemplate(t, root, fs, subMods, others...) if err != nil { return nil, err @@ -285,7 +295,7 @@ func getTemplate(root, file string, others ...string) (t *template.Template, err return } -func _getTemplate(t0 *template.Template, root string, subMods [][]string, others ...string) (t *template.Template, err error) { +func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { t = t0 for _, m := range subMods { if len(m) == 2 { @@ -297,11 +307,11 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others for _, otherFile := range others { if otherFile == m[1] { var subMods1 [][]string - t, subMods1, err = getTplDeep(root, otherFile, "", t) + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, subMods1, others...) + t, err = _getTemplate(t, root, fs, subMods1, others...) 
} break } @@ -310,8 +320,16 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others for _, otherFile := range others { var data []byte fileAbsPath := filepath.Join(root, otherFile) - data, err = ioutil.ReadFile(fileAbsPath) + f, err := fs.Open(fileAbsPath) if err != nil { + f.Close() + logs.Trace("template file parse error, not success open file:", err) + continue + } + data, err = ioutil.ReadAll(f) + f.Close() + if err != nil { + logs.Trace("template file parse error, not success read file:", err) continue } reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") @@ -319,11 +337,14 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others for _, sub := range allSub { if len(sub) == 2 && sub[1] == m[1] { var subMods1 [][]string - t, subMods1, err = getTplDeep(root, otherFile, "", t) + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, subMods1, others...) + t, err = _getTemplate(t, root, fs, subMods1, others...) + if err != nil { + logs.Trace("template parse file err:", err) + } } break } @@ -335,6 +356,15 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others return } +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} +func SetTemplateFSFunc(fnt templateFSFunc) { + beeTemplateFS = fnt +} + // SetViewsPath sets view directory path in beego application. func SetViewsPath(path string) *App { BConfig.WebConfig.ViewsPath = path diff --git a/template_test.go b/template_test.go index 2153ef72..287faadc 100644 --- a/template_test.go +++ b/template_test.go @@ -16,6 +16,9 @@ package beego import ( "bytes" + "github.com/astaxie/beego/testdata" + "github.com/elazarl/go-bindata-assetfs" + "net/http" "os" "path/filepath" "testing" @@ -256,3 +259,58 @@ func TestTemplateLayout(t *testing.T) { } os.RemoveAll(dir) } + +type TestingFileSystem struct { + assetfs *assetfs.AssetFS +} + +func (d TestingFileSystem) Open(name string) (http.File, error) { + return d.assetfs.Open(name) +} + +var outputBinData = ` + + + beego welcome template + + + + +

Hello, blocks!

+ + +

Hello, astaxie!

+ + + +

Hello

+

This is SomeVar: val

+ + +` + +func TestFsBinData(t *testing.T) { + SetTemplateFSFunc(func() http.FileSystem { + return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + }) + dir := "views" + if err := AddViewPath("views"); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + out := bytes.NewBufferString("") + if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + + if out.String() != outputBinData { + t.Log(out.String()) + t.Fatal("Compare failed") + } +} diff --git a/templatefunc.go b/templatefunc.go index e4d4667a..8c1504aa 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -692,7 +692,7 @@ func ge(arg1, arg2 interface{}) (bool, error) { // MapGet getting value from map by keys // usage: -// Data["m"] = map[string]interface{} { +// Data["m"] = M{ // "a": 1, // "1": map[string]float64{ // "c": 4, diff --git a/templatefunc_test.go b/templatefunc_test.go index aa82c26f..c7b8fbd3 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -329,7 +329,7 @@ func TestMapGet(t *testing.T) { } // test 2 level map - m2 := map[string]interface{}{ + m2 := M{ "1": map[string]float64{ "2": 3.5, }, @@ -344,11 +344,11 @@ func TestMapGet(t *testing.T) { } // test 5 level map - m5 := map[string]interface{}{ - "1": map[string]interface{}{ - "2": map[string]interface{}{ - "3": map[string]interface{}{ - "4": map[string]interface{}{ + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ "5": 1.2, }, }, diff --git a/testdata/Makefile b/testdata/Makefile new file mode 100644 index 00000000..e80e8238 --- /dev/null +++ b/testdata/Makefile @@ -0,0 +1,2 @@ +build_view: + $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/testdata/bindata.go b/testdata/bindata.go new file mode 100644 index 00000000..beade103 --- /dev/null +++ b/testdata/bindata.go @@ -0,0 +1,296 @@ +// Code generated by go-bindata. +// sources: +// views/blocks/block.tpl +// views/header.tpl +// views/index.tpl +// DO NOT EDIT! 
+ +package testdata + +import ( + "bytes" + "compress/gzip" + "fmt" + "github.com/elazarl/go-bindata-assetfs" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +func (fi bindataFileInfo) Name() string { + return fi.name +} +func (fi bindataFileInfo) Size() int64 { + return fi.size +} +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} +func (fi bindataFileInfo) IsDir() bool { + return false +} +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00") + +func viewsBlocksBlockTplBytes() ([]byte, error) { + return bindataRead( + _viewsBlocksBlockTpl, + "views/blocks/block.tpl", + ) +} + +func viewsBlocksBlockTpl() (*asset, error) { + bytes, err := viewsBlocksBlockTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsHeaderTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00") + +func viewsHeaderTplBytes() ([]byte, error) { + return bindataRead( + _viewsHeaderTpl, + "views/header.tpl", + ) +} + +func viewsHeaderTpl() (*asset, error) { + bytes, err := viewsHeaderTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00") + +func viewsIndexTplBytes() ([]byte, error) { + return bindataRead( + _viewsIndexTpl, + 
"views/index.tpl", + ) +} + +func viewsIndexTpl() (*asset, error) { + bytes, err := viewsIndexTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if err != nil { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() (*asset, error){ + "views/blocks/block.tpl": viewsBlocksBlockTpl, + "views/header.tpl": viewsHeaderTpl, + "views/index.tpl": viewsIndexTpl, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. 
+func AssetDir(name string) ([]string, error) {
+	node := _bintree
+	if len(name) != 0 {
+		cannonicalName := strings.Replace(name, "\\", "/", -1)
+		pathList := strings.Split(cannonicalName, "/")
+		for _, p := range pathList {
+			node = node.Children[p]
+			if node == nil {
+				return nil, fmt.Errorf("Asset %s not found", name)
+			}
+		}
+	}
+	if node.Func != nil {
+		return nil, fmt.Errorf("Asset %s not found", name)
+	}
+	rv := make([]string, 0, len(node.Children))
+	for childName := range node.Children {
+		rv = append(rv, childName)
+	}
+	return rv, nil
+}
+
+type bintree struct {
+	Func     func() (*asset, error)
+	Children map[string]*bintree
+}
+
+var _bintree = &bintree{nil, map[string]*bintree{
+	"views": &bintree{nil, map[string]*bintree{
+		"blocks": &bintree{nil, map[string]*bintree{
+			"block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}},
+		}},
+		"header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}},
+		"index.tpl":  &bintree{viewsIndexTpl, map[string]*bintree{}},
+	}},
+}}
+
+// RestoreAsset restores an asset under the given directory
+func RestoreAsset(dir, name string) error {
+	data, err := Asset(name)
+	if err != nil {
+		return err
+	}
+	info, err := AssetInfo(name)
+	if err != nil {
+		return err
+	}
+	err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755))
+	if err != nil {
+		return err
+	}
+	err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode())
+	if err != nil {
+		return err
+	}
+	err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime())
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// RestoreAssets restores an asset under the given directory recursively
+func RestoreAssets(dir, name string) error {
+	children, err := AssetDir(name)
+	// File
+	if err != nil {
+		return RestoreAsset(dir, name)
+	}
+	// Dir
+	for _, child := range children {
+		err = RestoreAssets(dir, filepath.Join(name, child))
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func _filePath(dir, name string) string {
+	cannonicalName := strings.Replace(name, "\\", "/", -1)
+	return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...)
+}
+
+func assetFS() *assetfs.AssetFS {
+	assetInfo := func(path string) (os.FileInfo, error) {
+		return os.Stat(path)
+	}
+	for k := range _bintree.Children {
+		return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k}
+	}
+	panic("unreachable")
+}
diff --git a/testdata/views/blocks/block.tpl b/testdata/views/blocks/block.tpl
new file mode 100644
index 00000000..2a9c57fc
--- /dev/null
+++ b/testdata/views/blocks/block.tpl
@@ -0,0 +1,3 @@
+{{define "block"}}
+<h1>Hello, blocks!</h1>
+{{end}}
\ No newline at end of file
diff --git a/testdata/views/header.tpl b/testdata/views/header.tpl
new file mode 100644
index 00000000..041fa403
--- /dev/null
+++ b/testdata/views/header.tpl
@@ -0,0 +1,3 @@
+{{define "header"}}
+<h1>Hello, astaxie!</h1>
+{{end}}
\ No newline at end of file
diff --git a/testdata/views/index.tpl b/testdata/views/index.tpl
new file mode 100644
index 00000000..21b7fc06
--- /dev/null
+++ b/testdata/views/index.tpl
@@ -0,0 +1,15 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>beego welcome template</title>
+  </head>
+  <body>
+    {{template "block"}}
+    {{template "header"}}
+    {{template "blocks/block.tpl"}}
+    <h2>{{ .Title }}</h2>
+    <p>This is SomeVar: {{ .SomeVar }}</p>
+ + diff --git a/toolbox/task.go b/toolbox/task.go index 672717cd..7e841e89 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -428,6 +428,9 @@ func run() { continue case <-changed: now = time.Now().Local() + for _, t := range AdminTaskList { + t.SetNext(now) + } continue case <-stop: return @@ -446,6 +449,7 @@ func StopTask() { // AddTask add task with name func AddTask(taskname string, t Tasker) { + t.SetNext(time.Now().Local()) AdminTaskList[taskname] = t if isstart { changed <- true diff --git a/utils/safemap_test.go b/utils/safemap_test.go index b3aa0896..65085195 100644 --- a/utils/safemap_test.go +++ b/utils/safemap_test.go @@ -26,6 +26,7 @@ func TestNewBeeMap(t *testing.T) { } func TestSet(t *testing.T) { + safeMap = NewBeeMap() if ok := safeMap.Set("astaxie", 1); !ok { t.Error("expected", true, "got", false) } diff --git a/vendor/github.com/Knetic/govaluate/.gitignore b/vendor/github.com/Knetic/govaluate/.gitignore new file mode 100644 index 00000000..da210fb3 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/.gitignore @@ -0,0 +1,28 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +coverage.out + +manual_test.go +*.out +*.err diff --git a/vendor/github.com/Knetic/govaluate/.travis.yml b/vendor/github.com/Knetic/govaluate/.travis.yml new file mode 100644 index 00000000..35ae404a --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/.travis.yml @@ -0,0 +1,10 @@ +language: go + +script: ./test.sh + +go: + - 1.2 + - 1.3 + - 1.4 + - 1.5 + - 1.6 diff --git a/vendor/github.com/Knetic/govaluate/CONTRIBUTORS b/vendor/github.com/Knetic/govaluate/CONTRIBUTORS new file mode 100644 index 00000000..364e5cf9 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/CONTRIBUTORS @@ -0,0 +1,12 @@ +This library was authored by George Lester, and contains contributions from: + +vjeantet (regex support) +iasci (ternary operator) +oxtoacart (parameter structures, deferred parameter retrieval) +wmiller848 (bitwise operators) +prashantv (optimization of bools) +dpaolella (exposure of variables used in an expression) +benpaxton (fix for missing type checks during literal elide process) +abrander (panic-finding testing tool) +xfennec (fix for dates being parsed in the current Location) +bgaifullin (lifting restriction on complex/struct types) \ No newline at end of file diff --git a/vendor/github.com/Knetic/govaluate/EvaluableExpression.go b/vendor/github.com/Knetic/govaluate/EvaluableExpression.go new file mode 100644 index 00000000..8f43112a --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/EvaluableExpression.go @@ -0,0 +1,272 @@ +package govaluate + +import ( + "errors" + "fmt" +) + +const isoDateFormat string = "2006-01-02T15:04:05.999999999Z0700" +const shortCircuitHolder int = -1 + +var DUMMY_PARAMETERS = MapParameters(map[string]interface{}{}) + +/* + EvaluableExpression represents a set of ExpressionTokens which, taken together, + are an expression that can be evaluated down into a single value. +*/ +type EvaluableExpression struct { + + /* + Represents the query format used to output dates. Typically only used when creating SQL or Mongo queries from an expression. + Defaults to the complete ISO8601 format, including nanoseconds. + */ + QueryDateFormat string + + /* + Whether or not to safely check types when evaluating. 
+ If true, this library will return error messages when invalid types are used. + If false, the library will panic when operators encounter types they can't use. + + This is exclusively for users who need to squeeze every ounce of speed out of the library as they can, + and you should only set this to false if you know exactly what you're doing. + */ + ChecksTypes bool + + tokens []ExpressionToken + evaluationStages *evaluationStage + inputExpression string +} + +/* + Parses a new EvaluableExpression from the given [expression] string. + Returns an error if the given expression has invalid syntax. +*/ +func NewEvaluableExpression(expression string) (*EvaluableExpression, error) { + + functions := make(map[string]ExpressionFunction) + return NewEvaluableExpressionWithFunctions(expression, functions) +} + +/* + Similar to [NewEvaluableExpression], except that instead of a string, an already-tokenized expression is given. + This is useful in cases where you may be generating an expression automatically, or using some other parser (e.g., to parse from a query language) +*/ +func NewEvaluableExpressionFromTokens(tokens []ExpressionToken) (*EvaluableExpression, error) { + + var ret *EvaluableExpression + var err error + + ret = new(EvaluableExpression) + ret.QueryDateFormat = isoDateFormat + + err = checkBalance(tokens) + if err != nil { + return nil, err + } + + err = checkExpressionSyntax(tokens) + if err != nil { + return nil, err + } + + ret.tokens, err = optimizeTokens(tokens) + if err != nil { + return nil, err + } + + ret.evaluationStages, err = planStages(ret.tokens) + if err != nil { + return nil, err + } + + ret.ChecksTypes = true + return ret, nil +} + +/* + Similar to [NewEvaluableExpression], except enables the use of user-defined functions. + Functions passed into this will be available to the expression. +*/ +func NewEvaluableExpressionWithFunctions(expression string, functions map[string]ExpressionFunction) (*EvaluableExpression, error) { + + var ret *EvaluableExpression + var err error + + ret = new(EvaluableExpression) + ret.QueryDateFormat = isoDateFormat + ret.inputExpression = expression + + ret.tokens, err = parseTokens(expression, functions) + if err != nil { + return nil, err + } + + err = checkBalance(ret.tokens) + if err != nil { + return nil, err + } + + err = checkExpressionSyntax(ret.tokens) + if err != nil { + return nil, err + } + + ret.tokens, err = optimizeTokens(ret.tokens) + if err != nil { + return nil, err + } + + ret.evaluationStages, err = planStages(ret.tokens) + if err != nil { + return nil, err + } + + ret.ChecksTypes = true + return ret, nil +} + +/* + Same as `Eval`, but automatically wraps a map of parameters into a `govalute.Parameters` structure. +*/ +func (this EvaluableExpression) Evaluate(parameters map[string]interface{}) (interface{}, error) { + + if parameters == nil { + return this.Eval(nil) + } + return this.Eval(MapParameters(parameters)) +} + +/* + Runs the entire expression using the given [parameters]. + e.g., If the expression contains a reference to the variable "foo", it will be taken from `parameters.Get("foo")`. + + This function returns errors if the combination of expression and parameters cannot be run, + such as if a variable in the expression is not present in [parameters]. + + In all non-error circumstances, this returns the single value result of the expression and parameters given. + e.g., if the expression is "1 + 1", this will return 2.0. 
+ e.g., if the expression is "foo + 1" and parameters contains "foo" = 2, this will return 3.0 +*/ +func (this EvaluableExpression) Eval(parameters Parameters) (interface{}, error) { + + if this.evaluationStages == nil { + return nil, nil + } + + if parameters != nil { + parameters = &sanitizedParameters{parameters} + } + return this.evaluateStage(this.evaluationStages, parameters) +} + +func (this EvaluableExpression) evaluateStage(stage *evaluationStage, parameters Parameters) (interface{}, error) { + + var left, right interface{} + var err error + + if stage.leftStage != nil { + left, err = this.evaluateStage(stage.leftStage, parameters) + if err != nil { + return nil, err + } + } + + if stage.isShortCircuitable() { + switch stage.symbol { + case AND: + if left == false { + return false, nil + } + case OR: + if left == true { + return true, nil + } + case COALESCE: + if left != nil { + return left, nil + } + + case TERNARY_TRUE: + if left == false { + right = shortCircuitHolder + } + case TERNARY_FALSE: + if left != nil { + right = shortCircuitHolder + } + } + } + + if right != shortCircuitHolder && stage.rightStage != nil { + right, err = this.evaluateStage(stage.rightStage, parameters) + if err != nil { + return nil, err + } + } + + if this.ChecksTypes { + if stage.typeCheck == nil { + + err = typeCheck(stage.leftTypeCheck, left, stage.symbol, stage.typeErrorFormat) + if err != nil { + return nil, err + } + + err = typeCheck(stage.rightTypeCheck, right, stage.symbol, stage.typeErrorFormat) + if err != nil { + return nil, err + } + } else { + // special case where the type check needs to know both sides to determine if the operator can handle it + if !stage.typeCheck(left, right) { + errorMsg := fmt.Sprintf(stage.typeErrorFormat, left, stage.symbol.String()) + return nil, errors.New(errorMsg) + } + } + } + + return stage.operator(left, right, parameters) +} + +func typeCheck(check stageTypeCheck, value interface{}, symbol OperatorSymbol, format string) error { + + if check == nil { + return nil + } + + if check(value) { + return nil + } + + errorMsg := fmt.Sprintf(format, value, symbol.String()) + return errors.New(errorMsg) +} + +/* + Returns an array representing the ExpressionTokens that make up this expression. +*/ +func (this EvaluableExpression) Tokens() []ExpressionToken { + + return this.tokens +} + +/* + Returns the original expression used to create this EvaluableExpression. +*/ +func (this EvaluableExpression) String() string { + + return this.inputExpression +} + +/* + Returns an array representing the variables contained in this EvaluableExpression. +*/ +func (this EvaluableExpression) Vars() []string { + var varlist []string + for _, val := range this.Tokens() { + if val.Kind == VARIABLE { + varlist = append(varlist, val.Value.(string)) + } + } + return varlist +} diff --git a/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go b/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go new file mode 100644 index 00000000..7e0ad1c8 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/EvaluableExpression_sql.go @@ -0,0 +1,167 @@ +package govaluate + +import ( + "errors" + "fmt" + "regexp" + "time" +) + +/* + Returns a string representing this expression as if it were written in SQL. + This function assumes that all parameters exist within the same table, and that the table essentially represents + a serialized object of some sort (e.g., hibernate). 
+ If your data model is more normalized, you may need to consider iterating through each actual token given by `Tokens()` + to create your query. + + Boolean values are considered to be "1" for true, "0" for false. + + Times are formatted according to this.QueryDateFormat. +*/ +func (this EvaluableExpression) ToSQLQuery() (string, error) { + + var stream *tokenStream + var transactions *expressionOutputStream + var transaction string + var err error + + stream = newTokenStream(this.tokens) + transactions = new(expressionOutputStream) + + for stream.hasNext() { + + transaction, err = this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + transactions.add(transaction) + } + + return transactions.createString(" "), nil +} + +func (this EvaluableExpression) findNextSQLString(stream *tokenStream, transactions *expressionOutputStream) (string, error) { + + var token ExpressionToken + var ret string + + token = stream.next() + + switch token.Kind { + + case STRING: + ret = fmt.Sprintf("'%v'", token.Value) + case PATTERN: + ret = fmt.Sprintf("'%s'", token.Value.(*regexp.Regexp).String()) + case TIME: + ret = fmt.Sprintf("'%s'", token.Value.(time.Time).Format(this.QueryDateFormat)) + + case LOGICALOP: + switch logicalSymbols[token.Value.(string)] { + + case AND: + ret = "AND" + case OR: + ret = "OR" + } + + case BOOLEAN: + if token.Value.(bool) { + ret = "1" + } else { + ret = "0" + } + + case VARIABLE: + ret = fmt.Sprintf("[%s]", token.Value.(string)) + + case NUMERIC: + ret = fmt.Sprintf("%g", token.Value.(float64)) + + case COMPARATOR: + switch comparatorSymbols[token.Value.(string)] { + + case EQ: + ret = "=" + case NEQ: + ret = "<>" + case REQ: + ret = "RLIKE" + case NREQ: + ret = "NOT RLIKE" + default: + ret = fmt.Sprintf("%s", token.Value.(string)) + } + + case TERNARY: + + switch ternarySymbols[token.Value.(string)] { + + case COALESCE: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("COALESCE(%v, %v)", left, right) + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + return "", errors.New("Ternary operators are unsupported in SQL output") + } + case PREFIX: + switch prefixSymbols[token.Value.(string)] { + + case INVERT: + ret = fmt.Sprintf("NOT") + default: + + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("%s%s", token.Value.(string), right) + } + case MODIFIER: + + switch modifierSymbols[token.Value.(string)] { + + case EXPONENT: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("POW(%s, %s)", left, right) + case MODULUS: + + left := transactions.rollback() + right, err := this.findNextSQLString(stream, transactions) + if err != nil { + return "", err + } + + ret = fmt.Sprintf("MOD(%s, %s)", left, right) + default: + ret = fmt.Sprintf("%s", token.Value.(string)) + } + case CLAUSE: + ret = "(" + case CLAUSE_CLOSE: + ret = ")" + case SEPARATOR: + ret = "," + + default: + errorMsg := fmt.Sprintf("Unrecognized query token '%s' of kind '%s'", token.Value, token.Kind) + return "", errors.New(errorMsg) + } + + return ret, nil +} diff --git a/vendor/github.com/Knetic/govaluate/ExpressionToken.go b/vendor/github.com/Knetic/govaluate/ExpressionToken.go new file mode 100644 index 00000000..f849f381 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/ExpressionToken.go @@ 
-0,0 +1,9 @@ +package govaluate + +/* + Represents a single parsed token. +*/ +type ExpressionToken struct { + Kind TokenKind + Value interface{} +} diff --git a/vendor/github.com/Knetic/govaluate/LICENSE b/vendor/github.com/Knetic/govaluate/LICENSE new file mode 100644 index 00000000..24b9b459 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014-2016 George Lester + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/Knetic/govaluate/MANUAL.md b/vendor/github.com/Knetic/govaluate/MANUAL.md new file mode 100644 index 00000000..e0658285 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/MANUAL.md @@ -0,0 +1,176 @@ +govaluate +==== + +This library contains quite a lot of functionality, this document is meant to be formal documentation on the operators and features of it. +Some of this documentation may duplicate what's in README.md, but should never conflict. + +# Types + +This library only officially deals with four types; `float64`, `bool`, `string`, and arrays. + +All numeric literals, with or without a radix, will be converted to `float64` for evaluation. For instance; in practice, there is no difference between the literals "1.0" and "1", they both end up as `float64`. This matters to users because if you intend to return numeric values from your expressions, then the returned value will be `float64`, not any other numeric type. + +Any string _literal_ (not parameter) which is interpretable as a date will be converted to a `float64` representation of that date's unix time. Any `time.Time` parameters will not be operable with these date literals; such parameters will need to use the `time.Time.Unix()` method to get a numeric representation. + +Arrays are untyped, and can be mixed-type. Internally they're all just `interface{}`. Only two operators can interact with arrays, `IN` and `,`. All other operators will refuse to operate on arrays. + +# Operators + +## Modifiers + +### Addition, concatenation `+` + +If either left or right sides of the `+` operator are a `string`, then this operator will perform string concatenation and return that result. If neither are string, then both must be numeric, and this will return a numeric result. + +Any other case is invalid. + +### Arithmetic `-` `*` `/` `**` `%` + +`**` refers to "take to the power of". For instance, `3 ** 4` == 81. 
+ +* _Left side_: numeric +* _Right side_: numeric +* _Returns_: numeric + +### Bitwise shifts, masks `>>` `<<` `|` `&` `^` + +All of these operators convert their `float64` left and right sides to `int64`, perform their operation, and then convert back. +Given how this library assumes numeric are represented (as `float64`), it is unlikely that this behavior will change, even though it may cause havoc with extremely large or small numbers. + +* _Left side_: numeric +* _Right side_: numeric +* _Returns_: numeric + +### Negation `-` + +Prefix only. This can never have a left-hand value. + +* _Right side_: numeric +* _Returns_: numeric + +### Inversion `!` + +Prefix only. This can never have a left-hand value. + +* _Right side_: bool +* _Returns_: bool + +### Bitwise NOT `~` + +Prefix only. This can never have a left-hand value. + +* _Right side_: numeric +* _Returns_: numeric + +## Logical Operators + +For all logical operators, this library will short-circuit the operation if the left-hand side is sufficient to determine what to do. For instance, `true || expensiveOperation()` will not actually call `expensiveOperation()`, since it knows the left-hand side is `true`. + +### Logical AND/OR `&&` `||` + +* _Left side_: bool +* _Right side_: bool +* _Returns_: bool + +### Ternary true `?` + +Checks if the left side is `true`. If so, returns the right side. If the left side is `false`, returns `nil`. +In practice, this is commonly used with the other ternary operator. + +* _Left side_: bool +* _Right side_: Any type. +* _Returns_: Right side or `nil` + +### Ternary false `:` + +Checks if the left side is `nil`. If so, returns the right side. If the left side is non-nil, returns the left side. +In practice, this is commonly used with the other ternary operator. + +* _Left side_: Any type. +* _Right side_: Any type. +* _Returns_: Right side or `nil` + +### Null coalescence `??` + +Similar to the C# operator. If the left value is non-nil, it returns that. If not, then the right-value is returned. + +* _Left side_: Any type. +* _Right side_: Any type. +* _Returns_: No specific type - whichever is passed to it. + +## Comparators + +### Numeric/lexicographic comparators `>` `<` `>=` `<=` + +If both sides are numeric, this returns the usual greater/lesser behavior that would be expected. +If both sides are string, this returns the lexicographic comparison of the strings. This uses Go's standard lexicographic compare. + +* _Accepts_: Left and right side must either be both string, or both numeric. +* _Returns_: bool + +### Regex comparators `=~` `!~` + +These use go's standard `regexp` flavor of regex. The left side is expected to be the candidate string, the right side is the pattern. `=~` returns whether or not the candidate string matches the regex pattern given on the right. `!~` is the inverted version of the same logic. + +* _Left side_: string +* _Right side_: string +* _Returns_: bool + +## Arrays + +### Separator `,` + +The separator, always paired with parenthesis, creates arrays. It must always have both a left and right-hand value, so for instance `(, 0)` and `(0,)` are invalid uses of it. + +Again, this should always be used with parenthesis; like `(1, 2, 3, 4)`. + +### Membership `IN` + +The only operator with a text name, this operator checks the right-hand side array to see if it contains a value that is equal to the left-side value. +Equality is determined by the use of the `==` operator, and this library doesn't check types between the values. 
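For example (a small illustration, not from the original manual): `5 in (1, 2, 3, 4, 5)` and `'foo' in ('foo', 'bar')` both evaluate to `true`, because some value in the right-hand array is `==` to the left side.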
Any two values that, when cast to `interface{}`, can still be checked for equality with `==` will act as expected.
+
+Note that you can use a parameter for the array, but it must be an `[]interface{}`.
+
+* _Left side_: Any type.
+* _Right side_: array
+* _Returns_: bool
+
+# Parameters
+
+Parameters must be passed in every time the expression is evaluated. Parameters can be of any type, but will not cause errors unless actually used in an erroneous way. There is no difference in behavior for any of the above operators for parameters - they are type checked when used.
+
+All `int` and `float` values of any width will be converted to `float64` before use.
+
+At no point is the parameter structure, or any value thereof, modified by this library.
+
+## Alternates to maps
+
+The default form of parameters as a map may not serve your use case. You may have parameters in some other structure, you may want to change the no-parameter-found behavior, or maybe even just have some debugging print statements invoked when a parameter is accessed.
+
+To do this, define a type that implements the `govaluate.Parameters` interface. When you want to evaluate, instead call `EvaluableExpression.Eval` and pass your parameter structure.
+
+# Functions
+
+During expression parsing (_not_ evaluation), a map of functions can be given to `govaluate.NewEvaluableExpressionWithFunctions` (the lengthiest and finest of function names). The resultant expression will be able to invoke those functions during evaluation. Once parsed, an expression cannot have functions added or removed - a new expression will need to be created if you want to change the functions, or behavior of said functions.
+
+Functions always take the form `<name>(<parameters>)`, including parens. Functions can have an empty list of parameters, like `()`, but still must have parens.
+
+If the expression contains something that looks like it ought to be a function (such as `foo()`), but no such function was given to it, it will error on parsing.
+
+Functions must be of type `map[string]govaluate.ExpressionFunction`. `ExpressionFunction`, for brevity, has the following signature:
+
+`func(args ...interface{}) (interface{}, error)`
+
+Where `args` is whatever is passed to the function when called. If a non-nil error is returned from a function during evaluation, the evaluation stops and ultimately returns that error to the caller of `Evaluate()` or `Eval()`.
+
+## Built-in functions
+
+There aren't any builtin functions. The author is opposed to maintaining a standard library of functions to be used.
+
+Every use case of this library is different, and even in simple use cases (such as parameters, see above) different users need different behavior, naming, or even functionality. The author prefers that users make their own decisions about what functions they need, and how they operate.
+
+# Equality
+
+The `==` and `!=` operators involve a moderately complex workflow. They use [`reflect.DeepEqual`](https://golang.org/pkg/reflect/#DeepEqual). This is for complicated reasons, but there are some types in Go that cannot be compared with the native `==` operator. Arrays, in particular, cannot be compared - Go will panic if you try. One might assume this could be handled with the type checking system in `govaluate`, but unfortunately without reflection there is no way to know if a variable is a slice/array. Worse, structs can be incomparable if they _contain incomparable types_.
+
+It's all very complicated.
Fortunately, Go includes the `reflect.DeepEqual` function to handle all the edge cases. Currently, `govaluate` uses that for all equality/inequality. diff --git a/vendor/github.com/Knetic/govaluate/OperatorSymbol.go b/vendor/github.com/Knetic/govaluate/OperatorSymbol.go new file mode 100644 index 00000000..48813d29 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/OperatorSymbol.go @@ -0,0 +1,306 @@ +package govaluate + +/* + Represents the valid symbols for operators. + +*/ +type OperatorSymbol int + +const ( + VALUE OperatorSymbol = iota + LITERAL + NOOP + EQ + NEQ + GT + LT + GTE + LTE + REQ + NREQ + IN + + AND + OR + + PLUS + MINUS + BITWISE_AND + BITWISE_OR + BITWISE_XOR + BITWISE_LSHIFT + BITWISE_RSHIFT + MULTIPLY + DIVIDE + MODULUS + EXPONENT + + NEGATE + INVERT + BITWISE_NOT + + TERNARY_TRUE + TERNARY_FALSE + COALESCE + + FUNCTIONAL + SEPARATE +) + +type operatorPrecedence int + +const ( + noopPrecedence operatorPrecedence = iota + valuePrecedence + functionalPrecedence + prefixPrecedence + exponentialPrecedence + additivePrecedence + bitwisePrecedence + bitwiseShiftPrecedence + multiplicativePrecedence + comparatorPrecedence + ternaryPrecedence + logicalAndPrecedence + logicalOrPrecedence + separatePrecedence +) + +func findOperatorPrecedenceForSymbol(symbol OperatorSymbol) operatorPrecedence { + + switch symbol { + case NOOP: + return noopPrecedence + case VALUE: + return valuePrecedence + case EQ: + fallthrough + case NEQ: + fallthrough + case GT: + fallthrough + case LT: + fallthrough + case GTE: + fallthrough + case LTE: + fallthrough + case REQ: + fallthrough + case NREQ: + fallthrough + case IN: + return comparatorPrecedence + case AND: + return logicalAndPrecedence + case OR: + return logicalOrPrecedence + case BITWISE_AND: + fallthrough + case BITWISE_OR: + fallthrough + case BITWISE_XOR: + return bitwisePrecedence + case BITWISE_LSHIFT: + fallthrough + case BITWISE_RSHIFT: + return bitwiseShiftPrecedence + case PLUS: + fallthrough + case MINUS: + return additivePrecedence + case MULTIPLY: + fallthrough + case DIVIDE: + fallthrough + case MODULUS: + return multiplicativePrecedence + case EXPONENT: + return exponentialPrecedence + case BITWISE_NOT: + fallthrough + case NEGATE: + fallthrough + case INVERT: + return prefixPrecedence + case COALESCE: + fallthrough + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + return ternaryPrecedence + case FUNCTIONAL: + return functionalPrecedence + case SEPARATE: + return separatePrecedence + } + + return valuePrecedence +} + +/* + Map of all valid comparators, and their string equivalents. + Used during parsing of expressions to determine if a symbol is, in fact, a comparator. + Also used during evaluation to determine exactly which comparator is being used. 
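+ Note that "in" is the only comparator that operates on arrays; all of the others compare two numbers or two strings.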
+*/ +var comparatorSymbols = map[string]OperatorSymbol{ + "==": EQ, + "!=": NEQ, + ">": GT, + ">=": GTE, + "<": LT, + "<=": LTE, + "=~": REQ, + "!~": NREQ, + "in": IN, +} + +var logicalSymbols = map[string]OperatorSymbol{ + "&&": AND, + "||": OR, +} + +var bitwiseSymbols = map[string]OperatorSymbol{ + "^": BITWISE_XOR, + "&": BITWISE_AND, + "|": BITWISE_OR, +} + +var bitwiseShiftSymbols = map[string]OperatorSymbol{ + ">>": BITWISE_RSHIFT, + "<<": BITWISE_LSHIFT, +} + +var additiveSymbols = map[string]OperatorSymbol{ + "+": PLUS, + "-": MINUS, +} + +var multiplicativeSymbols = map[string]OperatorSymbol{ + "*": MULTIPLY, + "/": DIVIDE, + "%": MODULUS, +} + +var exponentialSymbolsS = map[string]OperatorSymbol{ + "**": EXPONENT, +} + +var prefixSymbols = map[string]OperatorSymbol{ + "-": NEGATE, + "!": INVERT, + "~": BITWISE_NOT, +} + +var ternarySymbols = map[string]OperatorSymbol{ + "?": TERNARY_TRUE, + ":": TERNARY_FALSE, + "??": COALESCE, +} + +// this is defined separately from additiveSymbols et al because it's needed for parsing, not stage planning. +var modifierSymbols = map[string]OperatorSymbol{ + "+": PLUS, + "-": MINUS, + "*": MULTIPLY, + "/": DIVIDE, + "%": MODULUS, + "**": EXPONENT, + "&": BITWISE_AND, + "|": BITWISE_OR, + "^": BITWISE_XOR, + ">>": BITWISE_RSHIFT, + "<<": BITWISE_LSHIFT, +} + +var separatorSymbols = map[string]OperatorSymbol{ + ",": SEPARATE, +} + +/* + Returns true if this operator is contained by the given array of candidate symbols. + False otherwise. +*/ +func (this OperatorSymbol) IsModifierType(candidate []OperatorSymbol) bool { + + for _, symbolType := range candidate { + if this == symbolType { + return true + } + } + + return false +} + +/* + Generally used when formatting type check errors. + We could store the stringified symbol somewhere else and not require a duplicated codeblock to translate + OperatorSymbol to string, but that would require more memory, and another field somewhere. + Adding operators is rare enough that we just stringify it here instead. +*/ +func (this OperatorSymbol) String() string { + + switch this { + case NOOP: + return "NOOP" + case VALUE: + return "VALUE" + case EQ: + return "=" + case NEQ: + return "!=" + case GT: + return ">" + case LT: + return "<" + case GTE: + return ">=" + case LTE: + return "<=" + case REQ: + return "=~" + case NREQ: + return "!~" + case AND: + return "&&" + case OR: + return "||" + case IN: + return "in" + case BITWISE_AND: + return "&" + case BITWISE_OR: + return "|" + case BITWISE_XOR: + return "^" + case BITWISE_LSHIFT: + return "<<" + case BITWISE_RSHIFT: + return ">>" + case PLUS: + return "+" + case MINUS: + return "-" + case MULTIPLY: + return "*" + case DIVIDE: + return "/" + case MODULUS: + return "%" + case EXPONENT: + return "**" + case NEGATE: + return "-" + case INVERT: + return "!" + case BITWISE_NOT: + return "~" + case TERNARY_TRUE: + return "?" + case TERNARY_FALSE: + return ":" + case COALESCE: + return "??" 
+	}
+	return ""
+}
diff --git a/vendor/github.com/Knetic/govaluate/README.md b/vendor/github.com/Knetic/govaluate/README.md
new file mode 100644
index 00000000..42768e6d
--- /dev/null
+++ b/vendor/github.com/Knetic/govaluate/README.md
@@ -0,0 +1,210 @@
+govaluate
+====
+
+[![Build Status](https://travis-ci.org/Knetic/govaluate.svg?branch=master)](https://travis-ci.org/Knetic/govaluate)
+[![Godoc](https://godoc.org/github.com/Knetic/govaluate?status.png)](https://godoc.org/github.com/Knetic/govaluate)
+
+
+Provides support for evaluating arbitrary C-like arithmetic/string expressions.
+
+Why can't you just write these expressions in code?
+--
+
+Sometimes, you can't know ahead-of-time what an expression will look like, or you want those expressions to be configurable.
+Perhaps you've got a set of data running through your application, and you want to allow your users to specify some validations to run on it before committing it to a database. Or maybe you've written a monitoring framework which is capable of gathering a bunch of metrics, then evaluating a few expressions to see if any metrics should be alerted upon, but the conditions for alerting are different for each monitor.
+
+A lot of people wind up writing their own half-baked style of evaluation language that fits their needs, but isn't complete. Or they wind up baking the expression into the actual executable, even if they know it's subject to change. These strategies may work, but they take time to implement, time for users to learn, and induce technical debt as requirements change. This library is meant to cover all the normal C-like expressions, so that you don't have to reinvent one of the oldest wheels on a computer.
+
+How do I use it?
+--
+
+You create a new EvaluableExpression, then call "Evaluate" on it.
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("10 > 0");
+	result, err := expression.Evaluate(nil);
+	// result is now set to "true", the bool value.
+```
+
+Cool, but how about with parameters?
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("foo > 0");
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["foo"] = -1;
+
+	result, err := expression.Evaluate(parameters);
+	// result is now set to "false", the bool value.
+```
+
+That's cool, but we can almost certainly have done all that in code. What about a complex use case that involves some math?
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("(requests_made * requests_succeeded / 100) >= 90");
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["requests_made"] = 100;
+	parameters["requests_succeeded"] = 80;
+
+	result, err := expression.Evaluate(parameters);
+	// result is now set to "false", the bool value.
+```
+
+Or maybe you want to check the status of an alive check ("smoketest") page, which will be a string?
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("http_response_body == 'service is ok'");
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["http_response_body"] = "service is ok";
+
+	result, err := expression.Evaluate(parameters);
+	// result is now set to "true", the bool value.
+```
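+
+Regex comparisons work the same way; a sketch in the same style (the parameter name and pattern here are illustrative, not from the original README):
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("http_response_body =~ 'service is (ok|degraded)'");
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["http_response_body"] = "service is ok";
+
+	result, err := expression.Evaluate(parameters);
+	// result is now set to "true", the bool value, since the candidate string matches the pattern.
+```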
+
+These examples have all returned boolean values, but it's equally possible to return numeric ones.
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("(mem_used / total_mem) * 100");
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["total_mem"] = 1024;
+	parameters["mem_used"] = 512;
+
+	result, err := expression.Evaluate(parameters);
+	// result is now set to "50.0", the float64 value.
+```
+
+You can also do date parsing, though the formats are somewhat limited. Stick to RFC3339, ISO8601, unix date, or ruby date formats. If you're having trouble getting a date string to parse, check the list of formats actually used: [parsing.go:248](https://github.com/Knetic/govaluate/blob/0580e9b47a69125afa0e4ebd1cf93c49eb5a43ec/parsing.go#L258).
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("'2014-01-02' > '2014-01-01 23:59:59'");
+	result, err := expression.Evaluate(nil);
+
+	// result is now set to true
+```
+
+Expressions are parsed once, and can be re-used multiple times. Parsing is the compute-intensive phase of the process, so if you intend to use the same expression with different parameters, just parse it once. Like so:
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("response_time <= 100");
+	parameters := make(map[string]interface{}, 8)
+
+	for {
+		parameters["response_time"] = pingSomething();
+		result, err := expression.Evaluate(parameters)
+	}
+```
+
+The normal C-standard order of operators is respected. When writing an expression, be sure that you either order the operators correctly, or use parenthesis to clarify which portions of an expression should be run first.
+
+Escaping characters
+--
+
+Sometimes you'll have parameters that have spaces, slashes, pluses, ampersands or some other character
+that this library interprets as something special. For example, the following expression will not
+act as one might expect:
+
+	"response-time < 100"
+
+As written, the library will parse it as "[response] minus [time] is less than 100". In reality,
+"response-time" is meant to be one variable that just happens to have a dash in it.
+
+There are two ways to work around this. First, you can escape the entire parameter name:
+
+	"[response-time] < 100"
+
+Or you can use backslashes to escape only the minus sign.
+
+	"response\\-time < 100"
+
+Backslashes can be used anywhere in an expression to escape the very next character. Square bracketed parameter names can be used instead of plain parameter names at any time.
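+
+In Go source, the two forms from above look like this (a sketch built only from the strings in this section):
+
+```go
+	expression, err := govaluate.NewEvaluableExpression("[response-time] < 100");
+
+	// or, escaping only the dash; note the doubled backslash in Go source:
+	expression, err = govaluate.NewEvaluableExpression("response\\-time < 100");
+```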
+
+Functions
+--
+
+You may have cases where you want to call a function on a parameter during execution of the expression. Perhaps you want to aggregate some set of data, but don't know the exact aggregation you want to use until you're writing the expression itself. Or maybe you have a mathematical operation you want to perform, for which there is no operator; like `log` or `tan` or `sqrt`. For cases like this, you can provide a map of functions to `NewEvaluableExpressionWithFunctions`, which will then be able to use them during execution. For instance:
+
+```go
+	functions := map[string]govaluate.ExpressionFunction {
+		"strlen": func(args ...interface{}) (interface{}, error) {
+			length := len(args[0].(string))
+			return (float64)(length), nil
+		},
+	}
+
+	expString := "strlen('someReallyLongInputString') <= 16"
+	expression, _ := govaluate.NewEvaluableExpressionWithFunctions(expString, functions)
+
+	result, _ := expression.Evaluate(nil)
+	// result is now "false", the boolean value
+```
+
+Functions can accept any number of arguments, nested function calls are handled correctly, and arguments can be of any type (even if none of this library's operators support evaluation of that type). For instance, each of these usages of functions in an expression are valid (assuming that the appropriate functions and parameters are given):
+
+```go
+"sqrt(x1 ** y1, x2 ** y2)"
+"max(someValue, abs(anotherValue), 10 * lastValue)"
+```
+
+Functions cannot be passed as parameters, they must be known at the time when the expression is parsed, and are unchangeable after parsing.
+
+What operators and types does this support?
+--
+
+* Modifiers: `+` `-` `/` `*` `&` `|` `^` `**` `%` `>>` `<<`
+* Comparators: `>` `>=` `<` `<=` `==` `!=` `=~` `!~`
+* Logical ops: `||` `&&`
+* Numeric constants, as 64-bit floating point (`12345.678`)
+* String constants (single quotes: `'foobar'`)
+* Date constants (single quotes, using any permutation of RFC3339, ISO8601, ruby date, or unix date; date parsing is automatically tried with any string constant)
+* Boolean constants: `true` `false`
+* Parenthesis to control order of evaluation `(` `)`
+* Arrays (anything separated by `,` within parenthesis: `(1, 2, 'foo')`)
+* Prefixes: `!` `-` `~`
+* Ternary conditional: `?` `:`
+* Null coalescence: `??`
+
+See [MANUAL.md](https://github.com/Knetic/govaluate/blob/master/MANUAL.md) for exacting details on what types each operator supports.
+
+Types
+--
+
+Some operators don't make sense when used with some types. For instance, what does it mean to get the modulo of a string? What happens if you check to see if two numbers are logically AND'ed together?
+
+Everyone has a different intuition about the answers to these questions. To prevent confusion, this library will _refuse to operate_ upon types for which there is not an unambiguous meaning for the operation. See [MANUAL.md](https://github.com/Knetic/govaluate/blob/master/MANUAL.md) for details about what operators are valid for which types.
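+
+For example, comparing a number to a string fails cleanly instead of guessing (a sketch; the names are illustrative):
+
+```go
+	expression, _ := govaluate.NewEvaluableExpression("duration > 'abc'")
+
+	parameters := make(map[string]interface{}, 8)
+	parameters["duration"] = 10
+
+	result, err := expression.Evaluate(parameters)
+	// err is non-nil and result is nil: '>' compares two numbers or two strings, never a mix.
+```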
+
+Benchmarks
+--
+
+If you're concerned about the overhead of this library, a good range of benchmarks are built into this repo. You can run them with `go test -bench=.`. The library is built with an eye towards being quick, but has not been aggressively profiled and optimized. For most applications, though, it is completely fine.
+
+For a very rough idea of performance, here are the results output from a benchmark run on a 3rd-gen Macbook Pro (Linux Mint 17.1).
+
+```
+BenchmarkSingleParse-12                          1000000              1382 ns/op
+BenchmarkSimpleParse-12                           200000             10771 ns/op
+BenchmarkFullParse-12                              30000             49383 ns/op
+BenchmarkEvaluationSingle-12                    50000000              30.1 ns/op
+BenchmarkEvaluationNumericLiteral-12            10000000               119 ns/op
+BenchmarkEvaluationLiteralModifiers-12          10000000               236 ns/op
+BenchmarkEvaluationParameters-12                 5000000               260 ns/op
+BenchmarkEvaluationParametersModifiers-12        3000000               547 ns/op
+BenchmarkComplexExpression-12                    2000000               963 ns/op
+BenchmarkRegexExpression-12                       100000             20357 ns/op
+BenchmarkConstantRegexExpression-12              1000000              1392 ns/op
+ok
+```
+
+API Breaks
+--
+
+While this library has very few cases which will ever result in an API break, it can happen (and [has](https://github.com/Knetic/govaluate/releases/tag/2.0.0)). If you are using this in production, vendor the commit you've tested against, or use gopkg.in to redirect your import (e.g., `import "gopkg.in/Knetic/govaluate.v2"`). Master branch (while infrequent) _may_ at some point contain API breaking changes, and the author will have no way to communicate these to downstreams, other than creating a new major release.
+
+Releases will explicitly state when an API break happens, and if they do not specify an API break it should be safe to upgrade.
+
+License
+--
+
+This project is licensed under the MIT general use license. You're free to integrate, fork, and play with this code as you feel fit without consulting the author, as long as you provide proper credit to the author in your works.
diff --git a/vendor/github.com/Knetic/govaluate/TokenKind.go b/vendor/github.com/Knetic/govaluate/TokenKind.go
new file mode 100644
index 00000000..b7355bb7
--- /dev/null
+++ b/vendor/github.com/Knetic/govaluate/TokenKind.go
@@ -0,0 +1,72 @@
+package govaluate
+
+/*
+	Represents all valid types of tokens that a token can be.
+*/
+type TokenKind int
+
+const (
+	UNKNOWN TokenKind = iota
+
+	PREFIX
+	NUMERIC
+	BOOLEAN
+	STRING
+	PATTERN
+	TIME
+	VARIABLE
+	FUNCTION
+	SEPARATOR
+
+	COMPARATOR
+	LOGICALOP
+	MODIFIER
+
+	CLAUSE
+	CLAUSE_CLOSE
+
+	TERNARY
+)
+
+/*
+	String returns a string that describes the given TokenKind.
+	e.g., when passed the NUMERIC TokenKind, this returns the string "NUMERIC".
+*/ +func (kind TokenKind) String() string { + + switch kind { + + case PREFIX: + return "PREFIX" + case NUMERIC: + return "NUMERIC" + case BOOLEAN: + return "BOOLEAN" + case STRING: + return "STRING" + case PATTERN: + return "PATTERN" + case TIME: + return "TIME" + case VARIABLE: + return "VARIABLE" + case FUNCTION: + return "FUNCTION" + case SEPARATOR: + return "SEPARATOR" + case COMPARATOR: + return "COMPARATOR" + case LOGICALOP: + return "LOGICALOP" + case MODIFIER: + return "MODIFIER" + case CLAUSE: + return "CLAUSE" + case CLAUSE_CLOSE: + return "CLAUSE_CLOSE" + case TERNARY: + return "TERNARY" + } + + return "UNKNOWN" +} diff --git a/vendor/github.com/Knetic/govaluate/evaluationStage.go b/vendor/github.com/Knetic/govaluate/evaluationStage.go new file mode 100644 index 00000000..bb668819 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/evaluationStage.go @@ -0,0 +1,359 @@ +package govaluate + +import ( + "errors" + "fmt" + "math" + "regexp" + "reflect" +) + +const ( + logicalErrorFormat string = "Value '%v' cannot be used with the logical operator '%v', it is not a bool" + modifierErrorFormat string = "Value '%v' cannot be used with the modifier '%v', it is not a number" + comparatorErrorFormat string = "Value '%v' cannot be used with the comparator '%v', it is not a number" + ternaryErrorFormat string = "Value '%v' cannot be used with the ternary operator '%v', it is not a bool" + prefixErrorFormat string = "Value '%v' cannot be used with the prefix '%v'" +) + +type evaluationOperator func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) +type stageTypeCheck func(value interface{}) bool +type stageCombinedTypeCheck func(left interface{}, right interface{}) bool + +type evaluationStage struct { + symbol OperatorSymbol + + leftStage, rightStage *evaluationStage + + // the operation that will be used to evaluate this stage (such as adding [left] to [right] and return the result) + operator evaluationOperator + + // ensures that both left and right values are appropriate for this stage. Returns an error if they aren't operable. + leftTypeCheck stageTypeCheck + rightTypeCheck stageTypeCheck + + // if specified, will override whatever is used in "leftTypeCheck" and "rightTypeCheck". 
+ // primarily used for specific operators that don't care which side a given type is on, but still requires one side to be of a given type + // (like string concat) + typeCheck stageCombinedTypeCheck + + // regardless of which type check is used, this string format will be used as the error message for type errors + typeErrorFormat string +} + +var ( + _true = interface{}(true) + _false = interface{}(false) +) + +func (this *evaluationStage) swapWith(other *evaluationStage) { + + temp := *other + other.setToNonStage(*this) + this.setToNonStage(temp) +} + +func (this *evaluationStage) setToNonStage(other evaluationStage) { + + this.symbol = other.symbol + this.operator = other.operator + this.leftTypeCheck = other.leftTypeCheck + this.rightTypeCheck = other.rightTypeCheck + this.typeCheck = other.typeCheck + this.typeErrorFormat = other.typeErrorFormat +} + +func (this *evaluationStage) isShortCircuitable() bool { + + switch this.symbol { + case AND: + fallthrough + case OR: + fallthrough + case TERNARY_TRUE: + fallthrough + case TERNARY_FALSE: + fallthrough + case COALESCE: + return true + } + + return false +} + +func noopStageRight(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return right, nil +} + +func addStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + // string concat if either are strings + if isString(left) || isString(right) { + return fmt.Sprintf("%v%v", left, right), nil + } + + return left.(float64) + right.(float64), nil +} +func subtractStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) - right.(float64), nil +} +func multiplyStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) * right.(float64), nil +} +func divideStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return left.(float64) / right.(float64), nil +} +func exponentStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return math.Pow(left.(float64), right.(float64)), nil +} +func modulusStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return math.Mod(left.(float64), right.(float64)), nil +} +func gteStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) >= right.(string)), nil + } + return boolIface(left.(float64) >= right.(float64)), nil +} +func gtStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) > right.(string)), nil + } + return boolIface(left.(float64) > right.(float64)), nil +} +func lteStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) <= right.(string)), nil + } + return boolIface(left.(float64) <= right.(float64)), nil +} +func ltStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if isString(left) && isString(right) { + return boolIface(left.(string) < right.(string)), nil + } + return boolIface(left.(float64) < right.(float64)), nil +} +func equalStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(reflect.DeepEqual(left, right)), nil +} +func notEqualStage(left 
interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(!reflect.DeepEqual(left, right)), nil +} +func andStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(left.(bool) && right.(bool)), nil +} +func orStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(left.(bool) || right.(bool)), nil +} +func negateStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return -right.(float64), nil +} +func invertStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return boolIface(!right.(bool)), nil +} +func bitwiseNotStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(^int64(right.(float64))), nil +} +func ternaryIfStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if left.(bool) { + return right, nil + } + return nil, nil +} +func ternaryElseStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + if left != nil { + return left, nil + } + return right, nil +} + +func regexStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + var pattern *regexp.Regexp + var err error + + switch right.(type) { + case string: + pattern, err = regexp.Compile(right.(string)) + if err != nil { + return nil, errors.New(fmt.Sprintf("Unable to compile regexp pattern '%v': %v", right, err)) + } + case *regexp.Regexp: + pattern = right.(*regexp.Regexp) + } + + return pattern.Match([]byte(left.(string))), nil +} + +func notRegexStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + ret, err := regexStage(left, right, parameters) + if err != nil { + return nil, err + } + + return !(ret.(bool)), nil +} + +func bitwiseOrStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) | int64(right.(float64))), nil +} +func bitwiseAndStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) & int64(right.(float64))), nil +} +func bitwiseXORStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(int64(left.(float64)) ^ int64(right.(float64))), nil +} +func leftShiftStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(uint64(left.(float64)) << uint64(right.(float64))), nil +} +func rightShiftStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return float64(uint64(left.(float64)) >> uint64(right.(float64))), nil +} + +func makeParameterStage(parameterName string) evaluationOperator { + + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + value, err := parameters.Get(parameterName) + if err != nil { + return nil, err + } + + return value, nil + } +} + +func makeLiteralStage(literal interface{}) evaluationOperator { + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + return literal, nil + } +} + +func makeFunctionStage(function ExpressionFunction) evaluationOperator { + + return func(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + if right == nil { + return function() + } + + switch right.(type) { + case 
[]interface{}: + return function(right.([]interface{})...) + default: + return function(right) + } + + } +} + +func separatorStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + var ret []interface{} + + switch left.(type) { + case []interface{}: + ret = append(left.([]interface{}), right) + default: + ret = []interface{}{left, right} + } + + return ret, nil +} + +func inStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { + + for _, value := range right.([]interface{}) { + if left == value { + return true, nil + } + } + return false, nil +} + +// + +func isString(value interface{}) bool { + + switch value.(type) { + case string: + return true + } + return false +} + +func isRegexOrString(value interface{}) bool { + + switch value.(type) { + case string: + return true + case *regexp.Regexp: + return true + } + return false +} + +func isBool(value interface{}) bool { + switch value.(type) { + case bool: + return true + } + return false +} + +func isFloat64(value interface{}) bool { + switch value.(type) { + case float64: + return true + } + return false +} + +/* + Addition usually means between numbers, but can also mean string concat. + String concat needs one (or both) of the sides to be a string. +*/ +func additionTypeCheck(left interface{}, right interface{}) bool { + + if isFloat64(left) && isFloat64(right) { + return true + } + if !isString(left) && !isString(right) { + return false + } + return true +} + +/* + Comparison can either be between numbers, or lexicographic between two strings, + but never between the two. +*/ +func comparatorTypeCheck(left interface{}, right interface{}) bool { + + if isFloat64(left) && isFloat64(right) { + return true + } + if isString(left) && isString(right) { + return true + } + return false +} + +func isArray(value interface{}) bool { + switch value.(type) { + case []interface{}: + return true + } + return false +} + +/* + Converting a boolean to an interface{} requires an allocation. + We can use interned bools to avoid this cost. +*/ +func boolIface(b bool) interface{} { + if b { + return _true + } + return _false +} diff --git a/vendor/github.com/Knetic/govaluate/expressionFunctions.go b/vendor/github.com/Knetic/govaluate/expressionFunctions.go new file mode 100644 index 00000000..ac6592b3 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/expressionFunctions.go @@ -0,0 +1,8 @@ +package govaluate + +/* + Represents a function that can be called from within an expression. + This method must return an error if, for any reason, it is unable to produce exactly one unambiguous result. + An error returned will halt execution of the expression. +*/ +type ExpressionFunction func(arguments ...interface{}) (interface{}, error) diff --git a/vendor/github.com/Knetic/govaluate/expressionOutputStream.go b/vendor/github.com/Knetic/govaluate/expressionOutputStream.go new file mode 100644 index 00000000..88a84163 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/expressionOutputStream.go @@ -0,0 +1,46 @@ +package govaluate + +import ( + "bytes" +) + +/* + Holds a series of "transactions" which represent each token as it is output by an outputter (such as ToSQLQuery()). + Some outputs (such as SQL) require a function call or non-c-like syntax to represent an expression. + To accomplish this, this struct keeps track of each translated token as it is output, and can return and rollback those transactions. 
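+
+	A minimal illustration of the transaction model (hypothetical values, not taken
+	from the package's own docs):
+
+		stream.add("foo")
+		stream.add("bar")
+		last := stream.rollback()       // "bar"; only "foo" remains recorded
+		out := stream.createString(",") // "foo"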
+*/ +type expressionOutputStream struct { + transactions []string +} + +func (this *expressionOutputStream) add(transaction string) { + this.transactions = append(this.transactions, transaction) +} + +func (this *expressionOutputStream) rollback() string { + + index := len(this.transactions) - 1 + ret := this.transactions[index] + + this.transactions = this.transactions[:index] + return ret +} + +func (this *expressionOutputStream) createString(delimiter string) string { + + var retBuffer bytes.Buffer + var transaction string + + penultimate := len(this.transactions) - 1 + + for i := 0; i < penultimate; i++ { + + transaction = this.transactions[i] + + retBuffer.WriteString(transaction) + retBuffer.WriteString(delimiter) + } + retBuffer.WriteString(this.transactions[penultimate]) + + return retBuffer.String() +} diff --git a/vendor/github.com/Knetic/govaluate/lexerState.go b/vendor/github.com/Knetic/govaluate/lexerState.go new file mode 100644 index 00000000..9d43760e --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/lexerState.go @@ -0,0 +1,350 @@ +package govaluate + +import ( + "errors" + "fmt" +) + +type lexerState struct { + isEOF bool + isNullable bool + kind TokenKind + validNextKinds []TokenKind +} + +// lexer states. +// Constant for all purposes except compiler. +var validLexerStates = []lexerState{ + + lexerState{ + kind: UNKNOWN, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + PATTERN, + FUNCTION, + STRING, + TIME, + CLAUSE, + }, + }, + + lexerState{ + + kind: CLAUSE, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + PATTERN, + FUNCTION, + STRING, + TIME, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + + lexerState{ + + kind: CLAUSE_CLOSE, + isEOF: true, + isNullable: true, + validNextKinds: []TokenKind{ + + COMPARATOR, + MODIFIER, + NUMERIC, + BOOLEAN, + VARIABLE, + STRING, + PATTERN, + TIME, + CLAUSE, + CLAUSE_CLOSE, + LOGICALOP, + TERNARY, + SEPARATOR, + }, + }, + + lexerState{ + + kind: NUMERIC, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: BOOLEAN, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: STRING, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: TIME, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: PATTERN, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: VARIABLE, + isEOF: true, + isNullable: false, + validNextKinds: []TokenKind{ + + MODIFIER, + COMPARATOR, + LOGICALOP, + CLAUSE_CLOSE, + TERNARY, + SEPARATOR, + }, + }, + lexerState{ + + kind: MODIFIER, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + VARIABLE, + FUNCTION, + STRING, + BOOLEAN, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + lexerState{ + + kind: COMPARATOR, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + STRING, + TIME, + CLAUSE, 
+ CLAUSE_CLOSE, + PATTERN, + }, + }, + lexerState{ + + kind: LOGICALOP, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + STRING, + TIME, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + lexerState{ + + kind: PREFIX, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + NUMERIC, + BOOLEAN, + VARIABLE, + FUNCTION, + CLAUSE, + CLAUSE_CLOSE, + }, + }, + + lexerState{ + + kind: TERNARY, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + STRING, + TIME, + VARIABLE, + FUNCTION, + CLAUSE, + SEPARATOR, + }, + }, + lexerState{ + + kind: FUNCTION, + isEOF: false, + isNullable: false, + validNextKinds: []TokenKind{ + CLAUSE, + }, + }, + lexerState{ + + kind: SEPARATOR, + isEOF: false, + isNullable: true, + validNextKinds: []TokenKind{ + + PREFIX, + NUMERIC, + BOOLEAN, + STRING, + TIME, + VARIABLE, + FUNCTION, + CLAUSE, + }, + }, +} + +func (this lexerState) canTransitionTo(kind TokenKind) bool { + + for _, validKind := range this.validNextKinds { + + if validKind == kind { + return true + } + } + + return false +} + +func checkExpressionSyntax(tokens []ExpressionToken) error { + + var state lexerState + var lastToken ExpressionToken + var err error + + state = validLexerStates[0] + + for _, token := range tokens { + + if !state.canTransitionTo(token.Kind) { + + // call out a specific error for tokens looking like they want to be functions. + if lastToken.Kind == VARIABLE && token.Kind == CLAUSE { + return errors.New("Undefined function " + lastToken.Value.(string)) + } + + firstStateName := fmt.Sprintf("%s [%v]", state.kind.String(), lastToken.Value) + nextStateName := fmt.Sprintf("%s [%v]", token.Kind.String(), token.Value) + + return errors.New("Cannot transition token types from " + firstStateName + " to " + nextStateName) + } + + state, err = getLexerStateForToken(token.Kind) + if err != nil { + return err + } + + if !state.isNullable && token.Value == nil { + + errorMsg := fmt.Sprintf("Token kind '%v' cannot have a nil value", token.Kind.String()) + return errors.New(errorMsg) + } + + lastToken = token + } + + if !state.isEOF { + return errors.New("Unexpected end of expression") + } + return nil +} + +func getLexerStateForToken(kind TokenKind) (lexerState, error) { + + for _, possibleState := range validLexerStates { + + if possibleState.kind == kind { + return possibleState, nil + } + } + + errorMsg := fmt.Sprintf("No lexer state found for token kind '%v'\n", kind.String()) + return validLexerStates[0], errors.New(errorMsg) +} diff --git a/vendor/github.com/Knetic/govaluate/lexerStream.go b/vendor/github.com/Knetic/govaluate/lexerStream.go new file mode 100644 index 00000000..b72e6bdb --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/lexerStream.go @@ -0,0 +1,39 @@ +package govaluate + +type lexerStream struct { + source []rune + position int + length int +} + +func newLexerStream(source string) *lexerStream { + + var ret *lexerStream + var runes []rune + + for _, character := range source { + runes = append(runes, character) + } + + ret = new(lexerStream) + ret.source = runes + ret.length = len(runes) + return ret +} + +func (this *lexerStream) readCharacter() rune { + + var character rune + + character = this.source[this.position] + this.position += 1 + return character +} + +func (this *lexerStream) rewind(amount int) { + this.position -= amount +} + +func (this lexerStream) canRead() bool { + return this.position < this.length +} diff --git 
a/vendor/github.com/Knetic/govaluate/parameters.go b/vendor/github.com/Knetic/govaluate/parameters.go new file mode 100644 index 00000000..6c5b9ecb --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/parameters.go @@ -0,0 +1,32 @@ +package govaluate + +import ( + "errors" +) + +/* + Parameters is a collection of named parameters that can be used by an EvaluableExpression to retrieve parameters + when an expression tries to use them. +*/ +type Parameters interface { + + /* + Get gets the parameter of the given name, or an error if the parameter is unavailable. + Failure to find the given parameter should be indicated by returning an error. + */ + Get(name string) (interface{}, error) +} + +type MapParameters map[string]interface{} + +func (p MapParameters) Get(name string) (interface{}, error) { + + value, found := p[name] + + if !found { + errorMessage := "No parameter '" + name + "' found." + return nil, errors.New(errorMessage) + } + + return value, nil +} diff --git a/vendor/github.com/Knetic/govaluate/parsing.go b/vendor/github.com/Knetic/govaluate/parsing.go new file mode 100644 index 00000000..7eae7115 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/parsing.go @@ -0,0 +1,450 @@ +package govaluate + +import ( + "bytes" + "errors" + "fmt" + "regexp" + "strconv" + "time" + "unicode" +) + +func parseTokens(expression string, functions map[string]ExpressionFunction) ([]ExpressionToken, error) { + + var ret []ExpressionToken + var token ExpressionToken + var stream *lexerStream + var state lexerState + var err error + var found bool + + stream = newLexerStream(expression) + state = validLexerStates[0] + + for stream.canRead() { + + token, err, found = readToken(stream, state, functions) + + if err != nil { + return ret, err + } + + if !found { + break + } + + state, err = getLexerStateForToken(token.Kind) + if err != nil { + return ret, err + } + + // append this valid token + ret = append(ret, token) + } + + err = checkBalance(ret) + if err != nil { + return nil, err + } + + return ret, nil +} + +func readToken(stream *lexerStream, state lexerState, functions map[string]ExpressionFunction) (ExpressionToken, error, bool) { + + var function ExpressionFunction + var ret ExpressionToken + var tokenValue interface{} + var tokenTime time.Time + var tokenString string + var kind TokenKind + var character rune + var found bool + var completed bool + var err error + + // numeric is 0-9, or . 
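+	//   (e.g. "3.14" is read as a single NUMERIC token and parsed below with strconv.ParseFloat)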
+ // string starts with ' + // variable is alphanumeric, always starts with a letter + // bracket always means variable + // symbols are anything non-alphanumeric + // all others read into a buffer until they reach the end of the stream + for stream.canRead() { + + character = stream.readCharacter() + + if unicode.IsSpace(character) { + continue + } + + kind = UNKNOWN + + // numeric constant + if isNumeric(character) { + + tokenString = readTokenUntilFalse(stream, isNumeric) + tokenValue, err = strconv.ParseFloat(tokenString, 64) + + if err != nil { + errorMsg := fmt.Sprintf("Unable to parse numeric value '%v' to float64\n", tokenString) + return ExpressionToken{}, errors.New(errorMsg), false + } + kind = NUMERIC + break + } + + // comma, separator + if character == ',' { + + tokenValue = "," + kind = SEPARATOR + break + } + + // escaped variable + if character == '[' { + + tokenValue, completed = readUntilFalse(stream, true, false, true, isNotClosingBracket) + kind = VARIABLE + + if !completed { + return ExpressionToken{}, errors.New("Unclosed parameter bracket"), false + } + + // above method normally rewinds us to the closing bracket, which we want to skip. + stream.rewind(-1) + break + } + + // regular variable - or function? + if unicode.IsLetter(character) { + + tokenString = readTokenUntilFalse(stream, isVariableName) + + tokenValue = tokenString + kind = VARIABLE + + // boolean? + if tokenValue == "true" { + + kind = BOOLEAN + tokenValue = true + } else { + + if tokenValue == "false" { + + kind = BOOLEAN + tokenValue = false + } + } + + // textual operator? + if tokenValue == "in" || tokenValue == "IN" { + + // force lower case for consistency + tokenValue = "in" + kind = COMPARATOR + } + + // function? + function, found = functions[tokenString] + if found { + kind = FUNCTION + tokenValue = function + } + break + } + + if !isNotQuote(character) { + tokenValue, completed = readUntilFalse(stream, true, false, true, isNotQuote) + + if !completed { + return ExpressionToken{}, errors.New("Unclosed string literal"), false + } + + // advance the stream one position, since reading until false assumes the terminator is a real token + stream.rewind(-1) + + // check to see if this can be parsed as a time. + tokenTime, found = tryParseTime(tokenValue.(string)) + if found { + kind = TIME + tokenValue = tokenTime + } else { + kind = STRING + } + break + } + + if character == '(' { + tokenValue = character + kind = CLAUSE + break + } + + if character == ')' { + tokenValue = character + kind = CLAUSE_CLOSE + break + } + + // must be a known symbol + tokenString = readTokenUntilFalse(stream, isNotAlphanumeric) + tokenValue = tokenString + + // quick hack for the case where "-" can mean "prefixed negation" or "minus", which are used + // very differently. 
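+	// e.g. in "1 - -2" the first "-" follows a value and lexes as a MODIFIER (minus),
+	// while the second follows an operator and lexes as a PREFIX (negation).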
+ if state.canTransitionTo(PREFIX) { + _, found = prefixSymbols[tokenString] + if found { + + kind = PREFIX + break + } + } + _, found = modifierSymbols[tokenString] + if found { + + kind = MODIFIER + break + } + + _, found = logicalSymbols[tokenString] + if found { + + kind = LOGICALOP + break + } + + _, found = comparatorSymbols[tokenString] + if found { + + kind = COMPARATOR + break + } + + _, found = ternarySymbols[tokenString] + if found { + + kind = TERNARY + break + } + + errorMessage := fmt.Sprintf("Invalid token: '%s'", tokenString) + return ret, errors.New(errorMessage), false + } + + ret.Kind = kind + ret.Value = tokenValue + + return ret, nil, (kind != UNKNOWN) +} + +func readTokenUntilFalse(stream *lexerStream, condition func(rune) bool) string { + + var ret string + + stream.rewind(1) + ret, _ = readUntilFalse(stream, false, true, true, condition) + return ret +} + +/* + Returns the string that was read until the given [condition] was false, or whitespace was broken. + Returns false if the stream ended before whitespace was broken or condition was met. +*/ +func readUntilFalse(stream *lexerStream, includeWhitespace bool, breakWhitespace bool, allowEscaping bool, condition func(rune) bool) (string, bool) { + + var tokenBuffer bytes.Buffer + var character rune + var conditioned bool + + conditioned = false + + for stream.canRead() { + + character = stream.readCharacter() + + // Use backslashes to escape anything + if allowEscaping && character == '\\' { + + character = stream.readCharacter() + tokenBuffer.WriteString(string(character)) + continue + } + + if unicode.IsSpace(character) { + + if breakWhitespace && tokenBuffer.Len() > 0 { + conditioned = true + break + } + if !includeWhitespace { + continue + } + } + + if condition(character) { + tokenBuffer.WriteString(string(character)) + } else { + conditioned = true + stream.rewind(1) + break + } + } + + return tokenBuffer.String(), conditioned +} + +/* + Checks to see if any optimizations can be performed on the given [tokens], which form a complete, valid expression. + The returns slice will represent the optimized (or unmodified) list of tokens to use. +*/ +func optimizeTokens(tokens []ExpressionToken) ([]ExpressionToken, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var err error + var index int + + for index, token = range tokens { + + // if we find a regex operator, and the right-hand value is a constant, precompile and replace with a pattern. + if token.Kind != COMPARATOR { + continue + } + + symbol = comparatorSymbols[token.Value.(string)] + if symbol != REQ && symbol != NREQ { + continue + } + + index++ + token = tokens[index] + if token.Kind == STRING { + + token.Kind = PATTERN + token.Value, err = regexp.Compile(token.Value.(string)) + + if err != nil { + return tokens, err + } + + tokens[index] = token + } + } + return tokens, nil +} + +/* + Checks the balance of tokens which have multiple parts, such as parenthesis. +*/ +func checkBalance(tokens []ExpressionToken) error { + + var stream *tokenStream + var token ExpressionToken + var parens int + + stream = newTokenStream(tokens) + + for stream.hasNext() { + + token = stream.next() + if token.Kind == CLAUSE { + parens++ + continue + } + if token.Kind == CLAUSE_CLOSE { + parens-- + continue + } + } + + if parens != 0 { + return errors.New("Unbalanced parenthesis") + } + return nil +} + +func isNumeric(character rune) bool { + + return unicode.IsDigit(character) || character == '.' 
+} + +func isNotQuote(character rune) bool { + + return character != '\'' && character != '"' +} + +func isNotAlphanumeric(character rune) bool { + + return !(unicode.IsDigit(character) || + unicode.IsLetter(character) || + character == '(' || + character == ')' || + !isNotQuote(character)) +} + +func isVariableName(character rune) bool { + + return unicode.IsLetter(character) || + unicode.IsDigit(character) || + character == '_' +} + +func isNotClosingBracket(character rune) bool { + + return character != ']' +} + +/* + Attempts to parse the [candidate] as a Time. + Tries a series of standardized date formats, returns the Time if one applies, + otherwise returns false through the second return. +*/ +func tryParseTime(candidate string) (time.Time, bool) { + + var ret time.Time + var found bool + + timeFormats := [...]string{ + time.ANSIC, + time.UnixDate, + time.RubyDate, + time.Kitchen, + time.RFC3339, + time.RFC3339Nano, + "2006-01-02", // RFC 3339 + "2006-01-02 15:04", // RFC 3339 with minutes + "2006-01-02 15:04:05", // RFC 3339 with seconds + "2006-01-02 15:04:05-07:00", // RFC 3339 with seconds and timezone + "2006-01-02T15Z0700", // ISO8601 with hour + "2006-01-02T15:04Z0700", // ISO8601 with minutes + "2006-01-02T15:04:05Z0700", // ISO8601 with seconds + "2006-01-02T15:04:05.999999999Z0700", // ISO8601 with nanoseconds + } + + for _, format := range timeFormats { + + ret, found = tryParseExactTime(candidate, format) + if found { + return ret, true + } + } + + return time.Now(), false +} + +func tryParseExactTime(candidate string, format string) (time.Time, bool) { + + var ret time.Time + var err error + + ret, err = time.ParseInLocation(format, candidate, time.Local) + if err != nil { + return time.Now(), false + } + + return ret, true +} diff --git a/vendor/github.com/Knetic/govaluate/sanitizedParameters.go b/vendor/github.com/Knetic/govaluate/sanitizedParameters.go new file mode 100644 index 00000000..5cfab9b0 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/sanitizedParameters.go @@ -0,0 +1,71 @@ +package govaluate + +// sanitizedParameters is a wrapper for Parameters that does sanitization as +// parameters are accessed. +type sanitizedParameters struct { + orig Parameters +} + +func (p sanitizedParameters) Get(key string) (interface{}, error) { + value, err := p.orig.Get(key) + if err != nil { + return nil, err + } + + // should be converted to fixed point? 
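+	// (every Go integer type is widened to float64 here, so the evaluation stages
+	// only ever see one numeric type)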
+	if isFixedPoint(value) {
+		return castFixedPoint(value), nil
+	}
+
+	return value, nil
+}
+
+func isFixedPoint(value interface{}) bool {
+
+	switch value.(type) {
+	case uint8:
+		return true
+	case uint16:
+		return true
+	case uint32:
+		return true
+	case uint64:
+		return true
+	case int8:
+		return true
+	case int16:
+		return true
+	case int32:
+		return true
+	case int64:
+		return true
+	case int:
+		return true
+	}
+	return false
+}
+
+func castFixedPoint(value interface{}) float64 {
+	switch value.(type) {
+	case uint8:
+		return float64(value.(uint8))
+	case uint16:
+		return float64(value.(uint16))
+	case uint32:
+		return float64(value.(uint32))
+	case uint64:
+		return float64(value.(uint64))
+	case int8:
+		return float64(value.(int8))
+	case int16:
+		return float64(value.(int16))
+	case int32:
+		return float64(value.(int32))
+	case int64:
+		return float64(value.(int64))
+	case int:
+		return float64(value.(int))
+	}
+
+	return 0.0
+}
diff --git a/vendor/github.com/Knetic/govaluate/stagePlanner.go b/vendor/github.com/Knetic/govaluate/stagePlanner.go
new file mode 100644
index 00000000..a195c700
--- /dev/null
+++ b/vendor/github.com/Knetic/govaluate/stagePlanner.go
@@ -0,0 +1,675 @@
+package govaluate
+
+import (
+	"errors"
+	"fmt"
+	"time"
+)
+
+var stageSymbolMap = map[OperatorSymbol]evaluationOperator{
+	EQ:             equalStage,
+	NEQ:            notEqualStage,
+	GT:             gtStage,
+	LT:             ltStage,
+	GTE:            gteStage,
+	LTE:            lteStage,
+	REQ:            regexStage,
+	NREQ:           notRegexStage,
+	AND:            andStage,
+	OR:             orStage,
+	IN:             inStage,
+	BITWISE_OR:     bitwiseOrStage,
+	BITWISE_AND:    bitwiseAndStage,
+	BITWISE_XOR:    bitwiseXORStage,
+	BITWISE_LSHIFT: leftShiftStage,
+	BITWISE_RSHIFT: rightShiftStage,
+	PLUS:           addStage,
+	MINUS:          subtractStage,
+	MULTIPLY:       multiplyStage,
+	DIVIDE:         divideStage,
+	MODULUS:        modulusStage,
+	EXPONENT:       exponentStage,
+	NEGATE:         negateStage,
+	INVERT:         invertStage,
+	BITWISE_NOT:    bitwiseNotStage,
+	TERNARY_TRUE:   ternaryIfStage,
+	TERNARY_FALSE:  ternaryElseStage,
+	COALESCE:       ternaryElseStage,
+	SEPARATE:       separatorStage,
+}
+
+/*
+	A "precedent" is a function which will recursively parse new evaluationStages from a given stream of tokens.
+	It's called a `precedent` because it is expected to handle exactly one precedence of operator,
+	and defer to other `precedent`s for other operators.
+*/
+type precedent func(stream *tokenStream) (*evaluationStage, error)
+
+/*
+	A convenience struct for specifying the behavior of a `precedent`.
+	Most `precedent` functions can be described by the same function, just with different type checks, symbols, and error formats.
+	This struct is passed to `makePrecedentFromPlanner` to create a `precedent` function.
+*/
+type precedencePlanner struct {
+	validSymbols map[string]OperatorSymbol
+	validKinds   []TokenKind
+
+	typeErrorFormat string
+
+	next      precedent
+	nextRight precedent
+}
+
+var planPrefix precedent
+var planExponential precedent
+var planMultiplicative precedent
+var planAdditive precedent
+var planBitwise precedent
+var planShift precedent
+var planComparator precedent
+var planLogicalAnd precedent
+var planLogicalOr precedent
+var planTernary precedent
+var planSeparator precedent
+
+func init() {
+
+	// all these stages can use the same code (in `planPrecedenceLevel`) to execute,
+	// they simply need different type checks, symbols, and recursive precedents.
+	// While not all precedent phases are listed here, most can be represented this way.
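+	// The chain built below runs, lowest to highest precedence: separator -> ternary ->
+	// logical OR -> logical AND -> comparator -> bitwise -> shift -> additive ->
+	// multiplicative -> exponential, with prefixes and function calls planned above those.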
+ planPrefix = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: prefixSymbols, + validKinds: []TokenKind{PREFIX}, + typeErrorFormat: prefixErrorFormat, + nextRight: planFunction, + }) + planExponential = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: exponentialSymbolsS, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planFunction, + }) + planMultiplicative = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: multiplicativeSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planExponential, + }) + planAdditive = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: additiveSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planMultiplicative, + }) + planShift = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: bitwiseShiftSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planAdditive, + }) + planBitwise = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: bitwiseSymbols, + validKinds: []TokenKind{MODIFIER}, + typeErrorFormat: modifierErrorFormat, + next: planShift, + }) + planComparator = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: comparatorSymbols, + validKinds: []TokenKind{COMPARATOR}, + typeErrorFormat: comparatorErrorFormat, + next: planBitwise, + }) + planLogicalAnd = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: map[string]OperatorSymbol{"&&": AND}, + validKinds: []TokenKind{LOGICALOP}, + typeErrorFormat: logicalErrorFormat, + next: planComparator, + }) + planLogicalOr = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: map[string]OperatorSymbol{"||": OR}, + validKinds: []TokenKind{LOGICALOP}, + typeErrorFormat: logicalErrorFormat, + next: planLogicalAnd, + }) + planTernary = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: ternarySymbols, + validKinds: []TokenKind{TERNARY}, + typeErrorFormat: ternaryErrorFormat, + next: planLogicalOr, + }) + planSeparator = makePrecedentFromPlanner(&precedencePlanner{ + validSymbols: separatorSymbols, + validKinds: []TokenKind{SEPARATOR}, + next: planTernary, + }) +} + +/* + Given a planner, creates a function which will evaluate a specific precedence level of operators, + and link it to other `precedent`s which recurse to parse other precedence levels. +*/ +func makePrecedentFromPlanner(planner *precedencePlanner) precedent { + + var generated precedent + var nextRight precedent + + generated = func(stream *tokenStream) (*evaluationStage, error) { + return planPrecedenceLevel( + stream, + planner.typeErrorFormat, + planner.validSymbols, + planner.validKinds, + nextRight, + planner.next, + ) + } + + if planner.nextRight != nil { + nextRight = planner.nextRight + } else { + nextRight = generated + } + + return generated +} + +/* + Creates a `evaluationStageList` object which represents an execution plan (or tree) + which is used to completely evaluate a set of tokens at evaluation-time. + The three stages of evaluation can be thought of as parsing strings to tokens, then tokens to a stage list, then evaluation with parameters. +*/ +func planStages(tokens []ExpressionToken) (*evaluationStage, error) { + + stream := newTokenStream(tokens) + + stage, err := planTokens(stream) + if err != nil { + return nil, err + } + + // while we're now fully-planned, we now need to re-order same-precedence operators. 
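+	// (e.g. "1 - 2 - 3" is initially planned as 1 - (2 - 3); reordering makes it
+	// evaluate left-to-right as (1 - 2) - 3)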
+ // this could probably be avoided with a different planning method + reorderStages(stage) + + stage = elideLiterals(stage) + return stage, nil +} + +func planTokens(stream *tokenStream) (*evaluationStage, error) { + + if !stream.hasNext() { + return nil, nil + } + + return planSeparator(stream) +} + +/* + The most usual method of parsing an evaluation stage for a given precedence. + Most stages use the same logic +*/ +func planPrecedenceLevel( + stream *tokenStream, + typeErrorFormat string, + validSymbols map[string]OperatorSymbol, + validKinds []TokenKind, + rightPrecedent precedent, + leftPrecedent precedent) (*evaluationStage, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var leftStage, rightStage *evaluationStage + var checks typeChecks + var err error + var keyFound bool + + if leftPrecedent != nil { + + leftStage, err = leftPrecedent(stream) + if err != nil { + return nil, err + } + } + + for stream.hasNext() { + + token = stream.next() + + if len(validKinds) > 0 { + + keyFound = false + for _, kind := range validKinds { + if kind == token.Kind { + keyFound = true + break + } + } + + if !keyFound { + break + } + } + + if validSymbols != nil { + + if !isString(token.Value) { + break + } + + symbol, keyFound = validSymbols[token.Value.(string)] + if !keyFound { + break + } + } + + if rightPrecedent != nil { + rightStage, err = rightPrecedent(stream) + if err != nil { + return nil, err + } + } + + checks = findTypeChecks(symbol) + + return &evaluationStage{ + + symbol: symbol, + leftStage: leftStage, + rightStage: rightStage, + operator: stageSymbolMap[symbol], + + leftTypeCheck: checks.left, + rightTypeCheck: checks.right, + typeCheck: checks.combined, + typeErrorFormat: typeErrorFormat, + }, nil + } + + stream.rewind() + return leftStage, nil +} + +/* + A special case where functions need to be of higher precedence than values, and need a special wrapped execution stage operator. +*/ +func planFunction(stream *tokenStream) (*evaluationStage, error) { + + var token ExpressionToken + var rightStage *evaluationStage + var err error + + token = stream.next() + + if token.Kind != FUNCTION { + stream.rewind() + return planValue(stream) + } + + rightStage, err = planValue(stream) + if err != nil { + return nil, err + } + + return &evaluationStage{ + + symbol: FUNCTIONAL, + rightStage: rightStage, + operator: makeFunctionStage(token.Value.(ExpressionFunction)), + typeErrorFormat: "Unable to run function '%v': %v", + }, nil +} + +/* + A truly special precedence function, this handles all the "lowest-case" errata of the process, including literals, parmeters, + clauses, and prefixes. +*/ +func planValue(stream *tokenStream) (*evaluationStage, error) { + + var token ExpressionToken + var symbol OperatorSymbol + var ret *evaluationStage + var operator evaluationOperator + var err error + + token = stream.next() + + switch token.Kind { + + case CLAUSE: + + ret, err = planTokens(stream) + if err != nil { + return nil, err + } + + // advance past the CLAUSE_CLOSE token. We know that it's a CLAUSE_CLOSE, because at parse-time we check for unbalanced parens. + stream.next() + + // the stage we got represents all of the logic contained within the parens + // but for technical reasons, we need to wrap this stage in a "noop" stage which breaks long chains of precedence. + // see github #33. 
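+		// e.g. in "1 + (2 + 3) + 4" the parenthesized stage is fenced off by this noop,
+		// so reorderStages will not fold it into the surrounding additions.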
+ ret = &evaluationStage { + rightStage: ret, + operator: noopStageRight, + symbol: NOOP, + } + + return ret, nil + + case CLAUSE_CLOSE: + + // when functions have empty params, this will be hit. In this case, we don't have any evaluation stage to do, + // so we just return nil so that the stage planner continues on its way. + stream.rewind() + return nil, nil + + case VARIABLE: + operator = makeParameterStage(token.Value.(string)) + + case NUMERIC: + fallthrough + case STRING: + fallthrough + case PATTERN: + fallthrough + case BOOLEAN: + symbol = LITERAL + operator = makeLiteralStage(token.Value) + case TIME: + symbol = LITERAL + operator = makeLiteralStage(float64(token.Value.(time.Time).Unix())) + + case PREFIX: + stream.rewind() + return planPrefix(stream) + } + + if operator == nil { + errorMsg := fmt.Sprintf("Unable to plan token kind: '%s', value: '%v'", token.Kind.String(), token.Value) + return nil, errors.New(errorMsg) + } + + return &evaluationStage{ + symbol: symbol, + operator: operator, + }, nil +} + +/* + Convenience function to pass a triplet of typechecks between `findTypeChecks` and `planPrecedenceLevel`. + Each of these members may be nil, which indicates that type does not matter for that value. +*/ +type typeChecks struct { + left stageTypeCheck + right stageTypeCheck + combined stageCombinedTypeCheck +} + +/* + Maps a given [symbol] to a set of typechecks to be used during runtime. +*/ +func findTypeChecks(symbol OperatorSymbol) typeChecks { + + switch symbol { + case GT: + fallthrough + case LT: + fallthrough + case GTE: + fallthrough + case LTE: + return typeChecks{ + combined: comparatorTypeCheck, + } + case REQ: + fallthrough + case NREQ: + return typeChecks{ + left: isString, + right: isRegexOrString, + } + case AND: + fallthrough + case OR: + return typeChecks{ + left: isBool, + right: isBool, + } + case IN: + return typeChecks{ + right: isArray, + } + case BITWISE_LSHIFT: + fallthrough + case BITWISE_RSHIFT: + fallthrough + case BITWISE_OR: + fallthrough + case BITWISE_AND: + fallthrough + case BITWISE_XOR: + return typeChecks{ + left: isFloat64, + right: isFloat64, + } + case PLUS: + return typeChecks{ + combined: additionTypeCheck, + } + case MINUS: + fallthrough + case MULTIPLY: + fallthrough + case DIVIDE: + fallthrough + case MODULUS: + fallthrough + case EXPONENT: + return typeChecks{ + left: isFloat64, + right: isFloat64, + } + case NEGATE: + return typeChecks{ + right: isFloat64, + } + case INVERT: + return typeChecks{ + right: isBool, + } + case BITWISE_NOT: + return typeChecks{ + right: isFloat64, + } + case TERNARY_TRUE: + return typeChecks{ + left: isBool, + } + + // unchecked cases + case EQ: + fallthrough + case NEQ: + return typeChecks{} + case TERNARY_FALSE: + fallthrough + case COALESCE: + fallthrough + default: + return typeChecks{} + } +} + +/* + During stage planning, stages of equal precedence are parsed such that they'll be evaluated in reverse order. + For commutative operators like "+" or "-", it's no big deal. But for order-specific operators, it ruins the expected result. +*/ +func reorderStages(rootStage *evaluationStage) { + + // traverse every rightStage until we find multiples in a row of the same precedence. 
+ var identicalPrecedences []*evaluationStage + var currentStage, nextStage *evaluationStage + var precedence, currentPrecedence operatorPrecedence + + nextStage = rootStage + precedence = findOperatorPrecedenceForSymbol(rootStage.symbol) + + for nextStage != nil { + + currentStage = nextStage + nextStage = currentStage.rightStage + + // left depth first, since this entire method only looks for precedences down the right side of the tree + if currentStage.leftStage != nil { + reorderStages(currentStage.leftStage) + } + + currentPrecedence = findOperatorPrecedenceForSymbol(currentStage.symbol) + + if currentPrecedence == precedence { + identicalPrecedences = append(identicalPrecedences, currentStage) + continue + } + + // precedence break. + // See how many in a row we had, and reorder if there's more than one. + if len(identicalPrecedences) > 1 { + mirrorStageSubtree(identicalPrecedences) + } + + identicalPrecedences = []*evaluationStage{currentStage} + precedence = currentPrecedence + } + + if len(identicalPrecedences) > 1 { + mirrorStageSubtree(identicalPrecedences) + } +} + +/* + Performs a "mirror" on a subtree of stages. + This mirror functionally inverts the order of execution for all members of the [stages] list. + That list is assumed to be a root-to-leaf (ordered) list of evaluation stages, where each is a right-hand stage of the last. +*/ +func mirrorStageSubtree(stages []*evaluationStage) { + + var rootStage, inverseStage, carryStage, frontStage *evaluationStage + + stagesLength := len(stages) + + // reverse all right/left + for _, frontStage = range stages { + + carryStage = frontStage.rightStage + frontStage.rightStage = frontStage.leftStage + frontStage.leftStage = carryStage + } + + // end left swaps with root right + rootStage = stages[0] + frontStage = stages[stagesLength-1] + + carryStage = frontStage.leftStage + frontStage.leftStage = rootStage.rightStage + rootStage.rightStage = carryStage + + // for all non-root non-end stages, right is swapped with inverse stage right in list + for i := 0; i < (stagesLength-2)/2+1; i++ { + + frontStage = stages[i+1] + inverseStage = stages[stagesLength-i-1] + + carryStage = frontStage.rightStage + frontStage.rightStage = inverseStage.rightStage + inverseStage.rightStage = carryStage + } + + // swap all other information with inverse stages + for i := 0; i < stagesLength/2; i++ { + + frontStage = stages[i] + inverseStage = stages[stagesLength-i-1] + frontStage.swapWith(inverseStage) + } +} + +/* + Recurses through all operators in the entire tree, eliding operators where both sides are literals. +*/ +func elideLiterals(root *evaluationStage) *evaluationStage { + + if root.leftStage != nil { + root.leftStage = elideLiterals(root.leftStage) + } + + if root.rightStage != nil { + root.rightStage = elideLiterals(root.rightStage) + } + + return elideStage(root) +} + +/* + Elides a specific stage, if possible. + Returns the unmodified [root] stage if it cannot or should not be elided. + Otherwise, returns a new stage representing the condensed value from the elided stages. +*/ +func elideStage(root *evaluationStage) *evaluationStage { + + var leftValue, rightValue, result interface{} + var err error + + // right side must be a non-nil value. Left side must be nil or a value. 
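+	// (as written below, both stages must in fact be non-nil LITERALs for elision
+	// to proceed; anything else returns the stage unmodified)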
+	if root.rightStage == nil ||
+		root.rightStage.symbol != LITERAL ||
+		root.leftStage == nil ||
+		root.leftStage.symbol != LITERAL {
+		return root
+	}
+
+	// don't elide some operators
+	switch root.symbol {
+	case SEPARATE:
+		fallthrough
+	case IN:
+		return root
+	}
+
+	// both sides are values, get their actual values.
+	// errors should be near-impossible here. If we encounter them, just abort this optimization.
+	leftValue, err = root.leftStage.operator(nil, nil, nil)
+	if err != nil {
+		return root
+	}
+
+	rightValue, err = root.rightStage.operator(nil, nil, nil)
+	if err != nil {
+		return root
+	}
+
+	// typecheck, since the grammar checker is a bit loose with which operator symbols go together.
+	err = typeCheck(root.leftTypeCheck, leftValue, root.symbol, root.typeErrorFormat)
+	if err != nil {
+		return root
+	}
+
+	err = typeCheck(root.rightTypeCheck, rightValue, root.symbol, root.typeErrorFormat)
+	if err != nil {
+		return root
+	}
+
+	if root.typeCheck != nil && !root.typeCheck(leftValue, rightValue) {
+		return root
+	}
+
+	// pre-calculate, and return a new stage representing the result.
+	result, err = root.operator(leftValue, rightValue, nil)
+	if err != nil {
+		return root
+	}
+
+	return &evaluationStage{
+		symbol:   LITERAL,
+		operator: makeLiteralStage(result),
+	}
+}
diff --git a/vendor/github.com/Knetic/govaluate/test.sh b/vendor/github.com/Knetic/govaluate/test.sh
new file mode 100644
index 00000000..6a67ae7a
--- /dev/null
+++ b/vendor/github.com/Knetic/govaluate/test.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Script that runs tests, code coverage, and benchmarks all at once.
+# Builds a symlink in /tmp, mostly to avoid messing with GOPATH at the user's shell level.
+
+TEMPORARY_PATH="/tmp/govaluate_test"
+SRC_PATH="${TEMPORARY_PATH}/src"
+FULL_PATH="${TEMPORARY_PATH}/src/govaluate"
+
+# set up temporary directory
+rm -rf "${FULL_PATH}"
+mkdir -p "${SRC_PATH}"
+
+ln -s $(pwd) "${FULL_PATH}"
+export GOPATH="${TEMPORARY_PATH}"
+
+pushd "${TEMPORARY_PATH}/src/govaluate"
+
+# run the actual tests.
+export GOVALUATE_TORTURE_TEST="true"
+go test -bench=. -benchmem -coverprofile coverage.out
+status=$?
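+# capture the exit code; coverage is only printed below when the tests pass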
+ +if [ "${status}" != 0 ]; +then + exit $status +fi + +# coverage +go tool cover -func=coverage.out + +popd diff --git a/vendor/github.com/Knetic/govaluate/tokenStream.go b/vendor/github.com/Knetic/govaluate/tokenStream.go new file mode 100644 index 00000000..d0029209 --- /dev/null +++ b/vendor/github.com/Knetic/govaluate/tokenStream.go @@ -0,0 +1,36 @@ +package govaluate + +type tokenStream struct { + tokens []ExpressionToken + index int + tokenLength int +} + +func newTokenStream(tokens []ExpressionToken) *tokenStream { + + var ret *tokenStream + + ret = new(tokenStream) + ret.tokens = tokens + ret.tokenLength = len(tokens) + return ret +} + +func (this *tokenStream) rewind() { + this.index -= 1 +} + +func (this *tokenStream) next() ExpressionToken { + + var token ExpressionToken + + token = this.tokens[this.index] + + this.index += 1 + return token +} + +func (this tokenStream) hasNext() bool { + + return this.index < this.tokenLength +} diff --git a/vendor/github.com/beego/goyaml2/.gitignore b/vendor/github.com/beego/goyaml2/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/vendor/github.com/beego/goyaml2/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/beego/goyaml2/README.md b/vendor/github.com/beego/goyaml2/README.md new file mode 100644 index 00000000..7394caed --- /dev/null +++ b/vendor/github.com/beego/goyaml2/README.md @@ -0,0 +1,4 @@ +goyaml2 +======= + +YAML for Golang \ No newline at end of file diff --git a/vendor/github.com/beego/goyaml2/conv.go b/vendor/github.com/beego/goyaml2/conv.go new file mode 100644 index 00000000..9a1cbaa4 --- /dev/null +++ b/vendor/github.com/beego/goyaml2/conv.go @@ -0,0 +1,41 @@ +package goyaml2 + +import ( + "regexp" + "strconv" + //"time" +) + +var ( + RE_INT, _ = regexp.Compile("^[0-9,]+$") + RE_FLOAT, _ = regexp.Compile("^[0-9]+[.][0-9]+$") + RE_DATE, _ = regexp.Compile("^[0-9]{4}-[0-9]{2}-[0-9]{2}$") + RE_TIME, _ = regexp.Compile("^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$") +) + +func string2Val(str string) interface{} { + tmp := []byte(str) + switch { + case str == "false": + return false + case str == "true": + return true + case RE_INT.Match(tmp): + // TODO check err + _int, _ := strconv.ParseInt(str, 10, 64) + return _int + case RE_FLOAT.Match(tmp): + _float, _ := strconv.ParseFloat(str, 64) + return _float + //TODO support time or Not? 
+ /* + case RE_DATE.Match(tmp): + _date, _ := time.Parse("2006-01-02", str) + return _date + case RE_TIME.Match(tmp): + _time, _ := time.Parse("2006-01-02 03:04:05", str) + return _time + */ + } + return str +} diff --git a/vendor/github.com/beego/goyaml2/reader.go b/vendor/github.com/beego/goyaml2/reader.go new file mode 100644 index 00000000..5044d671 --- /dev/null +++ b/vendor/github.com/beego/goyaml2/reader.go @@ -0,0 +1,433 @@ +package goyaml2 + +import ( + "bufio" + "fmt" + "github.com/wendal/errors" + "io" + "log" + "strings" +) + +const ( + DEBUG = true + MAP_KEY_ONLY = iota +) + +func Read(r io.Reader) (interface{}, error) { + yr := &yamlReader{} + yr.br = bufio.NewReader(r) + obj, err := yr.ReadObject(0) + if err == io.EOF { + err = nil + } + if obj == nil { + log.Println("Obj == nil") + } + return obj, err +} + +type yamlReader struct { + br *bufio.Reader + nodes []interface{} + lineNum int + lastLine string +} + +func (y *yamlReader) ReadObject(minIndent int) (interface{}, error) { + line, err := y.NextLine() + if err != nil { + if err == io.EOF && line != "" { + //log.Println("Read EOF , but still some data here") + } else { + //log.Println("ReadERR", err) + return nil, err + } + } + y.lastLine = line + indent, str := getIndent(line) + if indent < minIndent { + //log.Println("Current Indent Unexpect : ", str, indent, minIndent) + return nil, y.Error("Unexpect Indent", nil) + } + if indent > minIndent { + //log.Println("Change minIndent from %d to %d", minIndent, indent) + minIndent = indent + } + switch str[0] { + case '-': + return y.ReadList(minIndent) + case '[': + fallthrough + case '{': + y.lastLine = "" + _, value, err := y.asMapKeyValue("tmp:" + str) + if err != nil { + return nil, y.Error("Err inline map/list", nil) + } + return value, nil + } + //log.Println("Read Objcet as Map", indent, str) + + return y.ReadMap(minIndent) + +} + +func (y *yamlReader) ReadList(minIndent int) ([]interface{}, error) { + list := []interface{}{} + for { + line, err := y.NextLine() + if err != nil { + return list, err + } + indent, str := getIndent(line) + switch { + case indent < minIndent: + y.lastLine = line + if len(list) == 0 { + return nil, nil + } + return list, nil + case indent == minIndent: + if str[0] != '-' { + y.lastLine = line + return list, nil + } + if len(str) < 2 { + return nil, y.Error("ListItem is Emtry", nil) + } + key, value, err := y.asMapKeyValue(str[1:]) + if err != nil { + return nil, err + } + switch value { + case nil: + list = append(list, key) + case MAP_KEY_ONLY: + return nil, y.Error("Not support List-Map yet", nil) + default: + _map := map[string]interface{}{key.(string): value} + list = append(list, _map) + + _line, _err := y.NextLine() + if _err != nil && _err != io.EOF { + return nil, err + } + if _line == "" { + return list, nil + } + y.lastLine = _line + _indent, _str := getIndent(line) + if _indent >= minIndent+2 { + switch _str[0] { + case '-': + return nil, y.Error("Unexpect", nil) + case '[': + return nil, y.Error("Unexpect", nil) + case '{': + return nil, y.Error("Unexpect", nil) + } + // look like a map + _map2, _err := y.ReadMap(_indent) + if _map2 != nil { + _map2[key.(string)] = value + } + if err != nil { + return list, _err + } + } + } + continue + default: + return nil, y.Error("Bad Indent\n"+line, nil) + } + } + panic("ERROR") + return nil, errors.New("Impossible") +} + +func (y *yamlReader) ReadMap(minIndent int) (map[string]interface{}, error) { + _map := map[string]interface{}{} + //log.Println("ReadMap", minIndent) +OUT: + for { + 
line, err := y.NextLine() + if err != nil { + return _map, err + } + indent, str := getIndent(line) + //log.Printf("Indent : %d, str = %s", indent, str) + switch { + case indent < minIndent: + y.lastLine = line + if len(_map) == 0 { + return nil, nil + } + return _map, nil + case indent == minIndent: + key, value, err := y.asMapKeyValue(str) + if err != nil { + return nil, err + } + //log.Println("Key=", key, "value=", value) + switch value { + case nil: + return nil, y.Error("Unexpect", nil) + case MAP_KEY_ONLY: + //log.Println("KeyOnly, read inner Map", key) + + //-------------------------------------- + _line, err := y.NextLine() + if err != nil { + if err == io.EOF { + if _line == "" { + // Emtry map item? + _map[key.(string)] = nil + return _map, err + } + } else { + return nil, y.Error("ERR?", err) + } + } + y.lastLine = _line + _indent, _str := getIndent(_line) + if _indent < minIndent { + return _map, nil + } + ////log.Println("##>>", _indent, _str) + if _indent == minIndent { + if _str[0] == '-' { + //log.Println("Read Same-Indent ListItem for Map") + _list, err := y.ReadList(minIndent) + if _list != nil { + _map[key.(string)] = _list + } + if err != nil { + return _map, nil + } + continue OUT + } else { + // Emtry map item? + _map[key.(string)] = nil + continue OUT + } + } + //-------------------------------------- + //log.Println("Read Map Item", _indent, _str) + + obj, err := y.ReadObject(_indent) + if obj != nil { + _map[key.(string)] = obj + } + if err != nil { + return _map, err + } + default: + _map[key.(string)] = value + } + default: + //log.Println("Bad", indent, str) + return nil, y.Error("Bad Indent\n"+line, nil) + } + } + panic("ERROR") + return nil, errors.New("Impossible") + +} + +func (y *yamlReader) NextLine() (line string, err error) { + if y.lastLine != "" { + line = y.lastLine + y.lastLine = "" + //log.Println("Return lastLine", line) + return + } + for { + line, err = y.br.ReadString('\n') + y.lineNum++ + if err != nil { + return + } + if strings.HasPrefix(line, "---") || strings.HasPrefix(line, "#") { + continue + } + + line = strings.TrimRight(line, "\n\t\r ") + if line == "" { + continue + } + //log.Println("Return Line", line) + return + } + //log.Println("Impossbible : " + line) + return // impossbile! 
+} + +func getIndent(str string) (int, string) { + indent := 0 + for i, s := range str { + switch s { + case ' ': + indent++ + case '\t': + indent += 4 + default: + return indent, str[i:] + } + } + panic("Invalid indent : " + str) + return -1, "" +} + +func (y *yamlReader) asMapKeyValue(str string) (key interface{}, val interface{}, err error) { + tokens := splitToken(str) + key = tokens[0] + if len(tokens) == 1 { + return key, nil, nil + } + if tokens[1] != ":" { + return "", nil, y.Error("Unexpect "+str, nil) + } + + if len(tokens) == 2 { + return key, MAP_KEY_ONLY, nil + } + if len(tokens) == 3 { + return key, tokens[2], nil + } + switch tokens[2] { + case "[": + list := []interface{}{} + for i := 3; i < len(tokens)-1; i++ { + list = append(list, tokens[i]) + } + return key, list, nil + case "{": + _map := map[string]interface{}{} + for i := 3; i < len(tokens)-1; i += 4 { + //log.Println(">>>", i, tokens[i]) + if i > len(tokens)-2 { + return "", nil, y.Error("Unexpect "+str, nil) + } + if tokens[i+1] != ":" { + return "", nil, y.Error("Unexpect "+str, nil) + } + _map[tokens[i].(string)] = tokens[i+2] + if (i + 3) < (len(tokens) - 1) { + if tokens[i+3] != "," { + return "", "", y.Error("Unexpect "+str, nil) + } + } else { + break + } + } + return key, _map, nil + } + //log.Println(str, tokens) + return "", nil, y.Error("Unexpect "+str, nil) +} + +func splitToken(str string) (tokens []interface{}) { + str = strings.Trim(str, "\r\t\n ") + if str == "" { + panic("Impossbile") + return + } + + tokens = []interface{}{} + lastPos := 0 + for i := 0; i < len(str); i++ { + switch str[i] { + case ':': + fallthrough + case '{': + fallthrough + case '[': + fallthrough + case '}': + fallthrough + case ']': + fallthrough + case ',': + if i > lastPos { + tokens = append(tokens, str[lastPos:i]) + } + tokens = append(tokens, str[i:i+1]) + lastPos = i + 1 + case ' ': + if i > lastPos { + tokens = append(tokens, str[lastPos:i]) + } + lastPos = i + 1 + case '\'': + //log.Println("Scan End of String") + i++ + start := i + for ; i < len(str); i++ { + if str[i] == '\'' { + //log.Println("Found End of String", start, i) + break + } + } + tokens = append(tokens, str[start:i]) + lastPos = i + 1 + case '"': + i++ + start := i + for ; i < len(str); i++ { + if str[i] == '"' { + break + } + } + tokens = append(tokens, str[start:i]) + lastPos = i + 1 + } + } + ////log.Println("last", lastPos) + if lastPos < len(str) { + tokens = append(tokens, str[lastPos:]) + } + + if len(tokens) == 1 { + tokens[0] = string2Val(tokens[0].(string)) + return + } + + if tokens[1] == ":" { + if len(tokens) == 2 { + return + } + if tokens[2] == "{" || tokens[2] == "[" { + return + } + str = strings.Trim(strings.SplitN(str, ":", 2)[1], "\t ") + if len(str) > 2 { + if str[0] == '\'' && str[len(str)-1] == '\'' { + str = str[1 : len(str)-1] + } else if str[0] == '"' && str[len(str)-1] == '"' { + str = str[1 : len(str)-1] + } + } + val := string2Val(str) + tokens = []interface{}{tokens[0], tokens[1], val} + return + } + + if len(str) > 2 { + if str[0] == '\'' && str[len(str)-1] == '\'' { + str = str[1 : len(str)-1] + } else if str[0] == '"' && str[len(str)-1] == '"' { + str = str[1 : len(str)-1] + } + } + val := string2Val(str) + tokens = []interface{}{val} + return +} + +func (y *yamlReader) Error(msg string, err error) error { + if err != nil { + return errors.New(fmt.Sprintf("line %d : %s : %v", y.lineNum, msg, err.Error())) + } + return errors.New(fmt.Sprintf("line %d >> %s", y.lineNum, msg)) +} diff --git 
a/vendor/github.com/beego/goyaml2/types.go b/vendor/github.com/beego/goyaml2/types.go new file mode 100644 index 00000000..ac595fcb --- /dev/null +++ b/vendor/github.com/beego/goyaml2/types.go @@ -0,0 +1,29 @@ +package goyaml2 + +const ( + N_Map = iota + N_List + N_String +) + +type Node interface { + Type() int +} + +type MapNode map[string]interface{} + +type ListNode []interface{} + +type StringNode string + +func (m *MapNode) Type() int { + return N_Map +} + +func (l *ListNode) Type() int { + return N_List +} + +func (s *StringNode) Type() int { + return N_String +} diff --git a/vendor/github.com/beego/goyaml2/writer.go b/vendor/github.com/beego/goyaml2/writer.go new file mode 100644 index 00000000..4143fe98 --- /dev/null +++ b/vendor/github.com/beego/goyaml2/writer.go @@ -0,0 +1,9 @@ +package goyaml2 + +import ( + "io" +) + +func Write(w io.Writer, v interface{}) error { + return nil +} diff --git a/vendor/github.com/beego/x2j/LICENSE b/vendor/github.com/beego/x2j/LICENSE new file mode 100644 index 00000000..93bacb6c --- /dev/null +++ b/vendor/github.com/beego/x2j/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2012-2013 Charles Banning . +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/beego/x2j/README b/vendor/github.com/beego/x2j/README new file mode 100644 index 00000000..bad2f232 --- /dev/null +++ b/vendor/github.com/beego/x2j/README @@ -0,0 +1,157 @@ +x2j.go - Unmarshal dynamic / arbitrary XML docs and extract values (using wildcards, if necessary). + +ANNOUNCEMENTS + +20 December 2013: + +Non-UTF8 character sets supported via the X2jCharsetReader variable. + +12 December 2013: + +For symmetry, the package j2x has functions that marshal JSON strings and +map[string]interface{} values to XML encoded strings: http://godoc.org/github.com/clbanning/j2x. + +Also, ToTree(), ToMap(), ToJson(), ToJsonIndent(), ReaderValuesFromTagPath() and ReaderValuesForTag() use io.Reader instead of string or []byte. + +If you want to process a stream of XML messages check out XmlMsgsFromReader(). 
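+
+As a rough sketch of the io.Reader-based API (the XML literal here is invented for
+illustration; ToMap is the function documented below and in reader2j.go):
+
+	rdr := strings.NewReader("<msg><to>gopher</to></msg>")
+	m, err := x2j.ToMap(rdr) // map[string]interface{} keyed by the root tag "msg"
+	if err != nil {
+		// handle malformed XML
+	}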
+ +MOTIVATION + +I make extensive use of JSON for messaging and typically unmarshal the messages into +map[string]interface{} variables. This is easily done using json.Unmarshal from the +standard Go libraries. Unfortunately, many legacy solutions use structured +XML messages; in those environments the applications would have to be refitted to +interoperate with my components. + +The better solution is to just provide an alternative HTTP handler that receives +XML doc messages and parses it into a map[string]interface{} variable and then reuse +all the JSON-based code. The Go xml.Unmarshal() function does not provide the same +option of unmarshaling XML messages into map[string]interface{} variables. So I wrote +a couple of small functions to fill this gap. + +Of course, once the XML doc was unmarshal'd into a map[string]interface{} variable it +was just a matter of calling json.Marshal() to provide it as a JSON string. Hence 'x2j' +rather than just 'x2m'. + +USAGE + +The package is fairly well self-documented. (http://godoc.org/github.com/clbanning/x2j) +The one really useful function is: + + - Unmarshal(doc []byte, v interface{}) error + where v is a pointer to a variable of type 'map[string]interface{}', 'string', or + any other type supported by xml.Unmarshal(). + +To retrieve a value for specific tag use: + + - DocValue(doc, path string, attrs ...string) (interface{},error) + - MapValue(m map[string]interface{}, path string, attr map[string]interface{}, recast ...bool) (interface{}, error) + +The 'path' argument is a period-separated tag hierarchy - also known as dot-notation. +It is the program's responsibility to cast the returned value to the proper type; possible +types are the normal JSON unmarshaling types: string, float64, bool, []interface, map[string]interface{}. + +To retrieve all values associated with a tag occurring anywhere in the XML document use: + + - ValuesForTag(doc, tag string) ([]interface{}, error) + - ValuesForKey(m map[string]interface{}, key string) []interface{} + + Demos: http://play.golang.org/p/m8zP-cpk0O + http://play.golang.org/p/cIteTS1iSg + http://play.golang.org/p/vd8pMiI21b + +Returned values should be one of map[string]interface, []interface{}, or string. + +All the values assocated with a tag-path that may include one or more wildcard characters - +'*' - can also be retrieved using: + + - ValuesFromTagPath(doc, path string, getAttrs ...bool) ([]interface{}, error) + - ValuesFromKeyPath(map[string]interface{}, path string, getAttrs ...bool) []interface{} + + Demos: http://play.golang.org/p/kUQnZ8VuhS + http://play.golang.org/p/l1aMHYtz7G + +NOTE: care should be taken when using "*" at the end of a path - i.e., "books.book.*". See +the x2jpath_test.go case on how the wildcard returns all key values and collapses list values; +the same message structure can load a []interface{} or a map[string]interface{} (or an interface{}) +value for a tag. + +See the test cases in "x2jpath_test.go" and programs in "example" subdirectory for more. + +XML PARSING CONVENTIONS + + - Attributes are parsed to map[string]interface{} values by prefixing a hyphen, '-', + to the attribute label. + - If the element is a simple element and has attributes, the element value + is given the key '#text' for its map[string]interface{} representation. (See + the 'atomFeedString.xml' test data, below.) + +BULK PROCESSING OF MESSAGE FILES + +Sometime messages may be logged into files for transmission via FTP (e.g.) and subsequent +processing. 
You can use the bulk XML message processor to convert files of XML messages into +map[string]interface{} values with custom processing and error handler functions. See +the notes and test code for: + + - XmlMsgsFromFile(fname string, phandler func(map[string]interface{}) bool, ehandler func(error) bool,recast ...bool) error + +IMPLEMENTATION NOTES + +Nothing fancy here, just brute force. + + - Use xml.Decoder to parse the XML doc and build a tree. + - Walk the tree and load values into a map[string]interface{} variable, 'm', as + appropriate. + - Use json.Marshaler to convert 'm' to JSON. + +As for testing: + + - Copy an XML doc into 'x2j_test.xml'. + - Run "go test" and you'll get a full dump. + ("pathTestString.xml" and "atomFeedString.xml" are test data from "read_test.go" + in the encoding/xml directory of the standard package library.) + +USES + + - putting a XML API on our message hub middleware (http://jsonhub.net) + - loading XML data into NoSQL database, such as, mongoDB + +PERFORMANCE IMPROVEMENTS WITH GO 1.1 and 1.2 + +Upgrading to Go 1.1 environment results in performance improvements for XML and JSON +unmarshalling, in general. The x2j package gets an average performance boost of 40%. + + ----- Go 1.0.2 ----- ----------- Go 1.1 ----------- + iterations ns/op iterations ns/op % improved +Benchmark_UseXml-4 100000 18776 200000 10377 45% +Benchmark_UseX2j-4 50000 55323 50000 33958 39% +Benchmark_UseJson-4 1000000 2257 1000000 1484 34% +Benchmark_UseJsonToMap-4 1000000 2531 1000000 1566 38% +BenchmarkBig_UseXml-4 100000 28918 100000 15876 45% +BenchmarkBig_UseX2j-4 20000 86338 50000 52661 39% +BenchmarkBig_UseJson-4 500000 4448 1000000 2664 40% +BenchmarkBig_UseJsonToMap-4 200000 9076 500000 5753 37% +BenchmarkBig3_UseXml-4 50000 42224 100000 24686 42% +BenchmarkBig3_UseX2j-4 10000 147407 20000 84332 43% +BenchmarkBig3_UseJson-4 500000 5921 500000 3930 34% +BenchmarkBig3_UseJsonToMap-4 200000 13037 200000 8670 33% + +The x2j package gets an additional 15-20% performance boost going to Go 1.2. + + ------ Go 1.1 ------ ----------- Go 1.2 ----------- + iterations ns/op iterations ns/op % improved +Benchmark_UseXml-4 200000 10377 200000 11031 -6% +Benchmark_UseX2j-4 50000 33958 100000 29188 14% +Benchmark_UseJson-4 1000000 1484 1000000 1347 9% +Benchmark_UseJsonToMap-4 1000000 1566 1000000 1434 8% +BenchmarkBig_UseXml-4 100000 15876 100000 16585 -4% +BenchmarkBig_UseX2j-4 50000 52661 50000 43452 17% +BenchmarkBig_UseJson-4 1000000 2664 1000000 2523 5% +BenchmarkBig_UseJsonToMap-4 500000 5753 500000 4992 13% +BenchmarkBig3_UseXml-4 100000 24686 100000 24348 1% +BenchmarkBig3_UseX2j-4 20000 84332 50000 66736 21% +BenchmarkBig3_UseJson-4 500000 3930 500000 3733 5% +BenchmarkBig3_UseJsonToMap-4 200000 8670 200000 7810 10% + + + diff --git a/vendor/github.com/beego/x2j/atomFeedString.xml b/vendor/github.com/beego/x2j/atomFeedString.xml new file mode 100644 index 00000000..474575a4 --- /dev/null +++ b/vendor/github.com/beego/x2j/atomFeedString.xml @@ -0,0 +1,54 @@ + +Code Review - My issueshttp://codereview.appspot.com/rietveld<>rietveld: an attempt at pubsubhubbub +2009-10-04T01:35:58+00:00email-address-removedurn:md5:134d9179c41f806be79b3a5f7877d19a + An attempt at adding pubsubhubbub support to Rietveld. +http://code.google.com/p/pubsubhubbub +http://code.google.com/p/rietveld/issues/detail?id=155 + +The server side of the protocol is trivial: + 1. add a &lt;link rel=&quot;hub&quot; href=&quot;hub-server&quot;&gt; tag to all + feeds that will be pubsubhubbubbed. + 2. 
every time one of those feeds changes, tell the hub + with a simple POST request. + +I have tested this by adding debug prints to a local hub +server and checking that the server got the right publish +requests. + +I can&#39;t quite get the server to work, but I think the bug +is not in my code. I think that the server expects to be +able to grab the feed and see the feed&#39;s actual URL in +the link rel=&quot;self&quot;, but the default value for that drops +the :port from the URL, and I cannot for the life of me +figure out how to get the Atom generator deep inside +django not to do that, or even where it is doing that, +or even what code is running to generate the Atom feed. +(I thought I knew but I added some assert False statements +and it kept running!) + +Ignoring that particular problem, I would appreciate +feedback on the right way to get the two values at +the top of feeds.py marked NOTE(rsc). + + +rietveld: correct tab handling +2009-10-03T23:02:17+00:00email-address-removedurn:md5:0a2a4f19bb815101f0ba2904aed7c35a + This fixes the buggy tab rendering that can be seen at +http://codereview.appspot.com/116075/diff/1/2 + +The fundamental problem was that the tab code was +not being told what column the text began in, so it +didn&#39;t know where to put the tab stops. Another problem +was that some of the code assumed that string byte +offsets were the same as column offsets, which is only +true if there are no tabs. + +In the process of fixing this, I cleaned up the arguments +to Fold and ExpandTabs and renamed them Break and +_ExpandTabs so that I could be sure that I found all the +call sites. I also wanted to verify that ExpandTabs was +not being used from outside intra_region_diff.py. + + + ` + diff --git a/vendor/github.com/beego/x2j/pathTestString.xml b/vendor/github.com/beego/x2j/pathTestString.xml new file mode 100644 index 00000000..b297e9d0 --- /dev/null +++ b/vendor/github.com/beego/x2j/pathTestString.xml @@ -0,0 +1,20 @@ + + 1 + + + A + + + B + + + C + D + + <_> + E + + + 2 + + diff --git a/vendor/github.com/beego/x2j/reader2j.go b/vendor/github.com/beego/x2j/reader2j.go new file mode 100644 index 00000000..11f08289 --- /dev/null +++ b/vendor/github.com/beego/x2j/reader2j.go @@ -0,0 +1,105 @@ +// io.Reader --> map[string]interface{} or JSON string +// nothing magic - just implements generic Go case + +package x2j + +import ( + "encoding/json" + "encoding/xml" + "io" +) + +// ToTree() - parse a XML io.Reader to a tree of Nodes +func ToTree(rdr io.Reader) (*Node, error) { + p := xml.NewDecoder(rdr) + p.CharsetReader = X2jCharsetReader + n, perr := xmlToTree("", nil, p) + if perr != nil { + return nil, perr + } + + return n, nil +} + +// ToMap() - parse a XML io.Reader to a map[string]interface{} +func ToMap(rdr io.Reader, recast ...bool) (map[string]interface{}, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + n, err := ToTree(rdr) + if err != nil { + return nil, err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(r) + + return m, nil +} + +// ToJson() - parse a XML io.Reader to a JSON string +func ToJson(rdr io.Reader, recast ...bool) (string, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + m, merr := ToMap(rdr, r) + if m == nil || merr != nil { + return "", merr + } + + b, berr := json.Marshal(m) + if berr != nil { + return "", berr + } + + return string(b), nil +} + +// ToJsonIndent - the pretty form of ReaderToJson +func ToJsonIndent(rdr io.Reader, recast ...bool) (string, error) { + var r bool + if 
len(recast) == 1 { + r = recast[0] + } + m, merr := ToMap(rdr, r) + if m == nil || merr != nil { + return "", merr + } + + b, berr := json.MarshalIndent(m, "", " ") + if berr != nil { + return "", berr + } + + // NOTE: don't have to worry about safe JSON marshaling with json.Marshal, since '<' and '>" are reservedin XML. + return string(b), nil +} + + +// ReaderValuesFromTagPath - io.Reader version of ValuesFromTagPath() +func ReaderValuesFromTagPath(rdr io.Reader, path string, getAttrs ...bool) ([]interface{}, error) { + var a bool + if len(getAttrs) == 1 { + a = getAttrs[0] + } + m, err := ToMap(rdr) + if err != nil { + return nil, err + } + + return ValuesFromKeyPath(m, path, a), nil +} + +// ReaderValuesForTag - io.Reader version of ValuesForTag() +func ReaderValuesForTag(rdr io.Reader, tag string) ([]interface{}, error) { + m, err := ToMap(rdr) + if err != nil { + return nil, err + } + + return ValuesForKey(m, tag), nil +} + + diff --git a/vendor/github.com/beego/x2j/songTextString.xml b/vendor/github.com/beego/x2j/songTextString.xml new file mode 100644 index 00000000..8c0f2bec --- /dev/null +++ b/vendor/github.com/beego/x2j/songTextString.xml @@ -0,0 +1,29 @@ + + help me! + + + + Henry was a renegade + Didn't like to play it safe + One component at a time + There's got to be a better way + Oh, people came from miles around + Searching for a steady job + Welcome to the Motor Town + Booming like an atom bomb + + + Oh, Henry was the end of the story + Then everything went wrong + And we'll return it to its former glory + But it just takes so long + + + + It's going to take a long time + It's going to take it, but we'll make it one day + It's going to take a long time + It's going to take it, but we'll make it one day + + + diff --git a/vendor/github.com/beego/x2j/x2j.go b/vendor/github.com/beego/x2j/x2j.go new file mode 100644 index 00000000..b9ca80e9 --- /dev/null +++ b/vendor/github.com/beego/x2j/x2j.go @@ -0,0 +1,656 @@ +// Unmarshal dynamic / arbitrary XML docs and extract values (using wildcards, if necessary). +// Copyright 2012-2013 Charles Banning. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file +/* + Unmarshal dynamic / arbitrary XML docs and extract values (using wildcards, if necessary). + + One useful function is: + + - Unmarshal(doc []byte, v interface{}) error + where v is a pointer to a variable of type 'map[string]interface{}', 'string', or + any other type supported by xml.Unmarshal(). + + To retrieve a value for specific tag use: + + - DocValue(doc, path string, attrs ...string) (interface{},error) + - MapValue(m map[string]interface{}, path string, attr map[string]interface{}, recast ...bool) (interface{}, error) + + The 'path' argument is a period-separated tag hierarchy - also known as dot-notation. + It is the program's responsibility to cast the returned value to the proper type; possible + types are the normal JSON unmarshaling types: string, float64, bool, []interface, map[string]interface{}. + + To retrieve all values associated with a tag occurring anywhere in the XML document use: + + - ValuesForTag(doc, tag string) ([]interface{}, error) + - ValuesForKey(m map[string]interface{}, key string) []interface{} + + Demos: http://play.golang.org/p/m8zP-cpk0O + http://play.golang.org/p/cIteTS1iSg + http://play.golang.org/p/vd8pMiI21b + + Returned values should be one of map[string]interface, []interface{}, or string. 
+ + All the values assocated with a tag-path that may include one or more wildcard characters - + '*' - can also be retrieved using: + + - ValuesFromTagPath(doc, path string, getAttrs ...bool) ([]interface{}, error) + - ValuesFromKeyPath(map[string]interface{}, path string, getAttrs ...bool) []interface{} + + Demos: http://play.golang.org/p/kUQnZ8VuhS + http://play.golang.org/p/l1aMHYtz7G + + NOTE: care should be taken when using "*" at the end of a path - i.e., "books.book.*". See + the x2jpath_test.go case on how the wildcard returns all key values and collapses list values; + the same message structure can load a []interface{} or a map[string]interface{} (or an interface{}) + value for a tag. + + See the test cases in "x2jpath_test.go" and programs in "example" subdirectory for more. + + XML PARSING CONVENTIONS + + - Attributes are parsed to map[string]interface{} values by prefixing a hyphen, '-', + to the attribute label. + - If the element is a simple element and has attributes, the element value + is given the key '#text' for its map[string]interface{} representation. (See + the 'atomFeedString.xml' test data, below.) + + io.Reader HANDLING + + ToTree(), ToMap(), ToJson(), and ToJsonIndent() provide parsing of messages from an io.Reader. + If you want to handle a message stream, look at XmlMsgsFromReader(). + + NON-UTF8 CHARACTER SETS + + Use the X2jCharsetReader variable to assign io.Reader for alternative character sets. +*/ +package x2j + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "regexp" + "strconv" + "strings" +) + +// If X2jCharsetReader != nil, it will be used to decode the doc or stream if required +// import charset "code.google.com/p/go-charset/charset" +// ... +// x2j.X2jCharsetReader = charset.NewReader +// s, err := x2j.DocToJson(doc) +var X2jCharsetReader func(charset string, input io.Reader)(io.Reader, error) + +type Node struct { + dup bool // is member of a list + attr bool // is an attribute + key string // XML tag + val string // element value + nodes []*Node +} + +// DocToJson - return an XML doc as a JSON string. +// If the optional argument 'recast' is 'true', then values will be converted to boolean or float64 if possible. +func DocToJson(doc string, recast ...bool) (string, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + m, merr := DocToMap(doc, r) + if m == nil || merr != nil { + return "", merr + } + + b, berr := json.Marshal(m) + if berr != nil { + return "", berr + } + + // NOTE: don't have to worry about safe JSON marshaling with json.Marshal, since '<' and '>" are reservedin XML. + return string(b), nil +} + +// DocToJsonIndent - return an XML doc as a prettified JSON string. +// If the optional argument 'recast' is 'true', then values will be converted to boolean or float64 if possible. +// Note: recasting is only applied to element values, not attribute values. +func DocToJsonIndent(doc string, recast ...bool) (string, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + m, merr := DocToMap(doc, r) + if m == nil || merr != nil { + return "", merr + } + + b, berr := json.MarshalIndent(m, "", " ") + if berr != nil { + return "", berr + } + + // NOTE: don't have to worry about safe JSON marshaling with json.Marshal, since '<' and '>" are reservedin XML. + return string(b), nil +} + +// DocToMap - convert an XML doc into a map[string]interface{}. +// (This is analogous to unmarshalling a JSON string to map[string]interface{} using json.Unmarshal().) 
+// If the optional argument 'recast' is 'true', then values will be converted to boolean or float64 if possible. +// Note: recasting is only applied to element values, not attribute values. +func DocToMap(doc string, recast ...bool) (map[string]interface{}, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + n, err := DocToTree(doc) + if err != nil { + return nil, err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(r) + + return m, nil +} + +// DocToTree - convert an XML doc into a tree of nodes. +func DocToTree(doc string) (*Node, error) { + // xml.Decoder doesn't properly handle whitespace in some doc + // see songTextString.xml test case ... + reg, _ := regexp.Compile("[ \t\n\r]*<") + doc = reg.ReplaceAllString(doc, "<") + + b := bytes.NewBufferString(doc) + p := xml.NewDecoder(b) + p.CharsetReader = X2jCharsetReader + n, berr := xmlToTree("", nil, p) + if berr != nil { + return nil, berr + } + + return n, nil +} + +// (*Node)WriteTree - convert a tree of nodes into a printable string. +// 'padding' is the starting indentation; typically: n.WriteTree(). +func (n *Node) WriteTree(padding ...int) string { + var indent int + if len(padding) == 1 { + indent = padding[0] + } + + var s string + if n.val != "" { + for i := 0; i < indent; i++ { + s += " " + } + s += n.key + " : " + n.val + "\n" + } else { + for i := 0; i < indent; i++ { + s += " " + } + s += n.key + " :" + "\n" + for _, nn := range n.nodes { + s += nn.WriteTree(indent + 1) + } + } + return s +} + +// xmlToTree - load a 'clean' XML doc into a tree of *Node. +func xmlToTree(skey string, a []xml.Attr, p *xml.Decoder) (*Node, error) { + n := new(Node) + n.nodes = make([]*Node, 0) + + if skey != "" { + n.key = skey + if len(a) > 0 { + for _, v := range a { + na := new(Node) + na.attr = true + na.key = `-` + v.Name.Local + na.val = v.Value + n.nodes = append(n.nodes, na) + } + } + } + for { + t, err := p.Token() + if err != nil { + return nil, err + } + switch t.(type) { + case xml.StartElement: + tt := t.(xml.StartElement) + // handle root + if n.key == "" { + n.key = tt.Name.Local + if len(tt.Attr) > 0 { + for _, v := range tt.Attr { + na := new(Node) + na.attr = true + na.key = `-` + v.Name.Local + na.val = v.Value + n.nodes = append(n.nodes, na) + } + } + } else { + nn, nnerr := xmlToTree(tt.Name.Local, tt.Attr, p) + if nnerr != nil { + return nil, nnerr + } + n.nodes = append(n.nodes, nn) + } + case xml.EndElement: + // scan n.nodes for duplicate n.key values + n.markDuplicateKeys() + return n, nil + case xml.CharData: + tt := string(t.(xml.CharData)) + if len(n.nodes) > 0 { + nn := new(Node) + nn.key = "#text" + nn.val = tt + n.nodes = append(n.nodes, nn) + } else { + n.val = tt + } + default: + // noop + } + } + // Logically we can't get here, but provide an error message anyway. + return nil, errors.New("Unknown parse error in xmlToTree() for: " + n.key) +} + +// (*Node)markDuplicateKeys - set node.dup flag for loading map[string]interface{}. +func (n *Node) markDuplicateKeys() { + l := len(n.nodes) + for i := 0; i < l; i++ { + if n.nodes[i].dup { + continue + } + for j := i + 1; j < l; j++ { + if n.nodes[i].key == n.nodes[j].key { + n.nodes[i].dup = true + n.nodes[j].dup = true + } + } + } +} + +// (*Node)treeToMap - convert a tree of nodes into a map[string]interface{}. +// (Parses to map that is structurally the same as from json.Unmarshal().) +// Note: root is not instantiated; call with: "m[n.key] = n.treeToMap(recast)". 
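+// Example (illustrative): parsing `<doc><a>1</a><a>2</a></doc>` yields
+//	m["doc"] == map[string]interface{}{"a": []interface{}{"1", "2"}}
+// since duplicate keys are collected into a list; with recast enabled the
+// strings "1" and "2" become float64 values.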
+func (n *Node) treeToMap(r bool) interface{} { + if len(n.nodes) == 0 { + return recast(n.val, r) + } + + m := make(map[string]interface{}, 0) + for _, v := range n.nodes { + // just a value + if !v.dup && len(v.nodes) == 0 { + m[v.key] = recast(v.val, r) + continue + } + + // a list of values + if v.dup { + var a []interface{} + if vv, ok := m[v.key]; ok { + a = vv.([]interface{}) + } else { + a = make([]interface{}, 0) + } + a = append(a, v.treeToMap(r)) + m[v.key] = interface{}(a) + continue + } + + // it's a unique key + m[v.key] = v.treeToMap(r) + } + + return interface{}(m) +} + +// recast - try to cast string values to bool or float64 +func recast(s string, r bool) interface{} { + if r { + // handle numeric strings ahead of boolean + if f, err := strconv.ParseFloat(s, 64); err == nil { + return interface{}(f) + } + // ParseBool treats "1"==true & "0"==false + if b, err := strconv.ParseBool(s); err == nil { + return interface{}(b) + } + } + return interface{}(s) +} + +// WriteMap - dumps the map[string]interface{} for examination. +// 'offset' is initial indentation count; typically: WriteMap(m). +// NOTE: with XML all element types are 'string'. +// But code written as generic for use with maps[string]interface{} values from json.Unmarshal(). +// Or it can handle a DocToMap(doc,true) result where values have been recast'd. +func WriteMap(m interface{}, offset ...int) string { + var indent int + if len(offset) == 1 { + indent = offset[0] + } + + var s string + switch m.(type) { + case nil: + return "[nil] nil" + case string: + return "[string] " + m.(string) + case float64: + return "[float64] " + strconv.FormatFloat(m.(float64), 'e', 2, 64) + case bool: + return "[bool] " + strconv.FormatBool(m.(bool)) + case []interface{}: + s += "[[]interface{}]" + for i, v := range m.([]interface{}) { + s += "\n" + for i := 0; i < indent; i++ { + s += " " + } + s += "[item: " + strconv.FormatInt(int64(i), 10) + "]" + switch v.(type) { + case string, float64, bool: + s += "\n" + default: + // noop + } + for i := 0; i < indent; i++ { + s += " " + } + s += WriteMap(v, indent+1) + } + case map[string]interface{}: + for k, v := range m.(map[string]interface{}) { + s += "\n" + for i := 0; i < indent; i++ { + s += " " + } + // s += "[map[string]interface{}] "+k+" :"+WriteMap(v,indent+1) + s += k + " :" + WriteMap(v, indent+1) + } + default: + // shouldn't ever be here ... + s += fmt.Sprintf("unknown type for: %v", m) + } + return s +} + +// ------------------------ value extraction from XML doc -------------------------- + +// DocValue - return a value for a specific tag +// 'doc' is a valid XML message. +// 'path' is a hierarchy of XML tags, e.g., "doc.name". +// 'attrs' is an OPTIONAL list of "name:value" pairs for attributes. +// Note: 'recast' is not enabled here. Use DocToMap(), NewAttributeMap(), and MapValue() calls for that. +func DocValue(doc, path string, attrs ...string) (interface{}, error) { + n, err := DocToTree(doc) + if err != nil { + return nil, err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(false) + + a, aerr := NewAttributeMap(attrs...) + if aerr != nil { + return nil, aerr + } + v, verr := MapValue(m, path, a) + if verr != nil { + return nil, verr + } + return v, nil +} + +// MapValue - retrieves value based on walking the map, 'm'. +// 'm' is the map value of interest. +// 'path' is a period-separated hierarchy of keys in the map. +// 'attr' is a map of attribute "name:value" pairs from NewAttributeMap(). May be 'nil'. 
+// If the path can't be traversed, an error is returned.
+// Note: the optional argument 'r' can be used to coerce attribute values, 'attr', if done so for 'm'.
+func MapValue(m map[string]interface{}, path string, attr map[string]interface{}, r ...bool) (interface{}, error) {
+	// attribute values may have been recast during map construction; default is 'false'.
+	if len(r) == 1 && r[0] {
+		for k, v := range attr {
+			attr[k] = recast(v.(string), true)
+		}
+	}
+
+	// parse the path
+	keys := strings.Split(path, ".")
+
+	// initialize return value to 'm' so a path of "" will work correctly
+	var v interface{} = m
+	var ok bool
+	var okey string
+	isMap := true
+	if keys[0] == "" && len(attr) == 0 {
+		return v, nil
+	}
+	for _, key := range keys {
+		if !isMap {
+			return nil, errors.New("no keys beyond: " + okey)
+		}
+		if v, ok = m[key]; !ok {
+			return nil, errors.New("no key in map: " + key)
+		} else {
+			switch v.(type) {
+			case map[string]interface{}:
+				m = v.(map[string]interface{})
+				isMap = true
+			default:
+				isMap = false
+			}
+		}
+		// save 'key' for error reporting
+		okey = key
+	}
+
+	// match attributes; value is "#text" or nil
+	if attr == nil {
+		return v, nil
+	}
+	return hasAttributes(v, attr)
+}
+
+// hasAttributes() - interface{} equality works for string, float64, bool
+func hasAttributes(v interface{}, a map[string]interface{}) (interface{}, error) {
+	switch v.(type) {
+	case []interface{}:
+		// run through all entries looking for one with matching attributes
+		for _, vv := range v.([]interface{}) {
+			if vvv, vvverr := hasAttributes(vv, a); vvverr == nil {
+				return vvv, nil
+			}
+		}
+		return nil, errors.New("no list member with matching attributes")
+	case map[string]interface{}:
+		// do all attribute name:value pairs match?
+		nv := v.(map[string]interface{})
+		for key, val := range a {
+			if vv, ok := nv[key]; !ok {
+				return nil, errors.New("no attribute with name: " + key[1:])
+			} else if val != vv {
+				return nil, errors.New("no attribute key:value pair: " + fmt.Sprintf("%s:%v", key[1:], val))
+			}
+		}
+		// they all match; so return the value associated with the "#text" key.
+		if vv, ok := nv["#text"]; ok {
+			return vv, nil
+		} else {
+			// this happens when another element is the value of the tag rather than just a string value
+			return nv, nil
+		}
+	}
+	return nil, errors.New("no match for attributes")
+}
+
+// NewAttributeMap() - generate map of attribute=value entries as map["-"+string]string.
+// 'kv' arguments are "name:value" pairs that appear as attributes, name="value".
+// If len(kv) == 0, the return is (nil, nil).
+func NewAttributeMap(kv ...string) (map[string]interface{}, error) {
+	if len(kv) == 0 {
+		return nil, nil
+	}
+	m := make(map[string]interface{})
+	for _, v := range kv {
+		vv := strings.Split(v, ":")
+		if len(vv) != 2 {
+			return nil, errors.New("attribute not \"name:value\" pair: " + v)
+		}
+		// attributes are stored as keys prepended with a hyphen
+		m["-"+vv[0]] = interface{}(vv[1])
+	}
+	return m, nil
+}
+
+//------------------------- get values for key ----------------------------
+
+// ValuesForTag - return all values in doc associated with 'tag'.
+// Returns nil if the 'tag' does not occur in the doc.
+// If there is an error encountered while parsing doc, that is returned.
+// If you want values 'recast' use DocToMap() and ValuesForKey().
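+// Example (illustrative; the doc and tag names are hypothetical):
+//
+//	doc := `<doc><book><author>A</author></book><book><author>B</author></book></doc>`
+//	vals, _ := ValuesForTag(doc, "author")
+//	// vals == []interface{}{"A", "B"}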
+func ValuesForTag(doc, tag string) ([]interface{}, error) { + m, err := DocToMap(doc) + if err != nil { + return nil, err + } + + return ValuesForKey(m, tag), nil +} + +// ValuesForKey - return all values in map associated with 'key' +// Returns nil if the 'key' does not occur in the map +func ValuesForKey(m map[string]interface{}, key string) []interface{} { + ret := make([]interface{}, 0) + + hasKey(m, key, &ret) + if len(ret) > 0 { + return ret + } + return nil +} + +// hasKey - if the map 'key' exists append it to array +// if it doesn't do nothing except scan array and map values +func hasKey(iv interface{}, key string, ret *[]interface{}) { + switch iv.(type) { + case map[string]interface{}: + vv := iv.(map[string]interface{}) + if v, ok := vv[key]; ok { + *ret = append(*ret, v) + } + for _, v := range iv.(map[string]interface{}) { + hasKey(v, key, ret) + } + case []interface{}: + for _, v := range iv.([]interface{}) { + hasKey(v, key, ret) + } + } +} + +// ======== 2013.07.01 - x2j.Unmarshal, wraps xml.Unmarshal ============== + +// Unmarshal - wraps xml.Unmarshal with handling of map[string]interface{} +// and string type variables. +// Usage: x2j.Unmarshal(doc,&m) where m of type map[string]interface{} +// x2j.Unmarshal(doc,&s) where s of type string (Overrides xml.Unmarshal().) +// x2j.Unmarshal(doc,&struct) - passed to xml.Unmarshal() +// x2j.Unmarshal(doc,&slice) - passed to xml.Unmarshal() +func Unmarshal(doc []byte, v interface{}) error { + switch v.(type) { + case *map[string]interface{}: + m, err := ByteDocToMap(doc) + vv := *v.(*map[string]interface{}) + for k, v := range m { + vv[k] = v + } + return err + case *string: + s, err := ByteDocToJson(doc) + *(v.(*string)) = s + return err + default: + b := bytes.NewBuffer(doc) + p := xml.NewDecoder(b) + p.CharsetReader = X2jCharsetReader + return p.Decode(v) + // return xml.Unmarshal(doc, v) + } + return nil +} + +// ByteDocToJson - return an XML doc as a JSON string. +// If the optional argument 'recast' is 'true', then values will be converted to boolean or float64 if possible. +func ByteDocToJson(doc []byte, recast ...bool) (string, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + m, merr := ByteDocToMap(doc, r) + if m == nil || merr != nil { + return "", merr + } + + b, berr := json.Marshal(m) + if berr != nil { + return "", berr + } + + // NOTE: don't have to worry about safe JSON marshaling with json.Marshal, since '<' and '>" are reservedin XML. + return string(b), nil +} + +// ByteDocToMap - convert an XML doc into a map[string]interface{}. +// (This is analogous to unmarshalling a JSON string to map[string]interface{} using json.Unmarshal().) +// If the optional argument 'recast' is 'true', then values will be converted to boolean or float64 if possible. +// Note: recasting is only applied to element values, not attribute values. +func ByteDocToMap(doc []byte, recast ...bool) (map[string]interface{}, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + n, err := ByteDocToTree(doc) + if err != nil { + return nil, err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(r) + + return m, nil +} + +// ByteDocToTree - convert an XML doc into a tree of nodes. +func ByteDocToTree(doc []byte) (*Node, error) { + // xml.Decoder doesn't properly handle whitespace in some doc + // see songTextString.xml test case ... 
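+	// e.g., "  <doc>\n\t<x>1</x>\n</doc>" collapses to "<doc><x>1</x></doc>"
+	// before decoding (illustrative input).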
+ reg, _ := regexp.Compile("[ \t\n\r]*<") + doc = reg.ReplaceAll(doc, []byte("<")) + + b := bytes.NewBuffer(doc) + p := xml.NewDecoder(b) + p.CharsetReader = X2jCharsetReader + n, berr := xmlToTree("", nil, p) + if berr != nil { + return nil, berr + } + + return n, nil +} + diff --git a/vendor/github.com/beego/x2j/x2j_bulk.go b/vendor/github.com/beego/x2j/x2j_bulk.go new file mode 100644 index 00000000..bc887753 --- /dev/null +++ b/vendor/github.com/beego/x2j/x2j_bulk.go @@ -0,0 +1,127 @@ +// Copyright 2012-2013 Charles Banning. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file + +// x2j_bulk.go: Process files with multiple XML messages. +// Extends x2m_bulk.go to work with JSON strings rather than map[string]interface{}. + +package x2j + +import ( + "bytes" + "encoding/json" + "io" + "os" + "regexp" +) + +// XmlMsgsFromFileAsJson() +// 'fname' is name of file +// 'phandler' is the JSON string processing handler. Return of 'false' stops further processing. +// 'ehandler' is the parsing error handler. Return of 'false' stops further processing and returns error. +// Note: phandler() and ehandler() calls are blocking, so reading and processing of messages is serialized. +// This means that you can stop reading the file on error or after processing a particular message. +// To have reading and handling run concurrently, pass arguments to a go routine in handler and return true. +func XmlMsgsFromFileAsJson(fname string, phandler func(string)(bool), ehandler func(error)(bool), recast ...bool) error { + var r bool + if len(recast) == 1 { + r = recast[0] + } + fi, fierr := os.Stat(fname) + if fierr != nil { + return fierr + } + fh, fherr := os.Open(fname) + if fherr != nil { + return fherr + } + defer fh.Close() + buf := make([]byte,fi.Size()) + _, rerr := fh.Read(buf) + if rerr != nil { + return rerr + } + doc := string(buf) + + // xml.Decoder doesn't properly handle whitespace in some doc + // see songTextString.xml test case ... + reg,_ := regexp.Compile("[ \t\n\r]*<") + doc = reg.ReplaceAllString(doc,"<") + b := bytes.NewBufferString(doc) + + for { + s, serr := XmlBufferToJson(b,r) + if serr != nil && serr != io.EOF { + if ok := ehandler(serr); !ok { + // caused reader termination + return serr + } + } + if s != "" { + if ok := phandler(s); !ok { + break + } + } + if serr == io.EOF { + break + } + } + return nil +} + +// XmlBufferToJson - process XML message from a bytes.Buffer +// 'b' is the buffer +// Optional argument 'recast' coerces values to float64 or bool where possible. +func XmlBufferToJson(b *bytes.Buffer,recast ...bool) (string,error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + + n,err := XmlBufferToTree(b) + if err != nil { + return "", err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(r) + + j, jerr := json.Marshal(m) + return string(j), jerr +} + +// ============================= io.Reader version for stream processing ====================== + +// XmlMsgsFromReaderAsJson() - io.Reader version of XmlMsgsFromFileAsJson +// 'rdr' is an io.Reader for an XML message (stream) +// 'phandler' is the JSON string processing handler. Return of 'false' stops further processing. +// 'ehandler' is the parsing error handler. Return of 'false' stops further processing and returns error. +// Note: phandler() and ehandler() calls are blocking, so reading and processing of messages is serialized. 
+// This means that you can stop reading the file on error or after processing a particular message. +// To have reading and handling run concurrently, pass arguments to a go routine in handler and return true. +func XmlMsgsFromReaderAsJson(rdr io.Reader, phandler func(string)(bool), ehandler func(error)(bool), recast ...bool) error { + var r bool + if len(recast) == 1 { + r = recast[0] + } + + for { + s, serr := ToJson(rdr,r) + if serr != nil && serr != io.EOF { + if ok := ehandler(serr); !ok { + // caused reader termination + return serr + } + } + if s != "" { + if ok := phandler(s); !ok { + break + } + } + if serr == io.EOF { + break + } + } + return nil +} + diff --git a/vendor/github.com/beego/x2j/x2j_test.xml b/vendor/github.com/beego/x2j/x2j_test.xml new file mode 100644 index 00000000..8c0f2bec --- /dev/null +++ b/vendor/github.com/beego/x2j/x2j_test.xml @@ -0,0 +1,29 @@ + + help me! + + + + Henry was a renegade + Didn't like to play it safe + One component at a time + There's got to be a better way + Oh, people came from miles around + Searching for a steady job + Welcome to the Motor Town + Booming like an atom bomb + + + Oh, Henry was the end of the story + Then everything went wrong + And we'll return it to its former glory + But it just takes so long + + + + It's going to take a long time + It's going to take it, but we'll make it one day + It's going to take a long time + It's going to take it, but we'll make it one day + + + diff --git a/vendor/github.com/beego/x2j/x2j_valuesFrom.go b/vendor/github.com/beego/x2j/x2j_valuesFrom.go new file mode 100644 index 00000000..fdb6cec4 --- /dev/null +++ b/vendor/github.com/beego/x2j/x2j_valuesFrom.go @@ -0,0 +1,125 @@ +// Copyright 2012-2013 Charles Banning. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file + +// x2j_valuesFrom.go: Extract values from an arbitrary XML doc. Tag path can include wildcard characters. + +package x2j + +import ( + "strings" +) + +// ------------------- sweep up everything for some point in the node tree --------------------- + +// ValuesFromTagPath - deliver all values for a path node from a XML doc +// If there are no values for the path 'nil' is returned. +// A return value of (nil, nil) means that there were no values and no errors parsing the doc. +// 'doc' is the XML document +// 'path' is a dot-separated path of tag nodes +// 'getAttrs' can be set 'true' to return attribute values for "*"-terminated path +// If a node is '*', then everything beyond is scanned for values. +// E.g., "doc.books' might return a single value 'book' of type []interface{}, but +// "doc.books.*" could return all the 'book' entries as []map[string]interface{}. +// "doc.books.*.author" might return all the 'author' tag values as []string - or +// "doc.books.*.author.lastname" might be required, depending on he schema. +func ValuesFromTagPath(doc, path string, getAttrs ...bool) ([]interface{}, error) { + var a bool + if len(getAttrs) == 1 { + a = getAttrs[0] + } + m, err := DocToMap(doc) + if err != nil { + return nil, err + } + + v := ValuesFromKeyPath(m, path, a) + return v, nil +} + +// ValuesFromKeyPath - deliver all values for a path node from a map[string]interface{} +// If there are no values for the path 'nil' is returned. +// 'm' is the map to be walked +// 'path' is a dot-separated path of key values +// 'getAttrs' can be set 'true' to return attribute values for "*"-terminated path +// If a node is '*', then everything beyond is walked. 
+// E.g., see ValuesFromTagPath documentation. +func ValuesFromKeyPath(m map[string]interface{}, path string, getAttrs ...bool) []interface{} { + var a bool + if len(getAttrs) == 1 { + a = getAttrs[0] + } + keys := strings.Split(path, ".") + ret := make([]interface{}, 0) + valuesFromKeyPath(&ret, m, keys, a) + if len(ret) == 0 { + return nil + } + return ret +} + +func valuesFromKeyPath(ret *[]interface{}, m interface{}, keys []string, getAttrs bool) { + lenKeys := len(keys) + + // load 'm' values into 'ret' + // expand any lists + if lenKeys == 0 { + switch m.(type) { + case map[string]interface{}: + *ret = append(*ret, m) + case []interface{}: + for _, v := range m.([]interface{}) { + *ret = append(*ret, v) + } + default: + *ret = append(*ret, m) + } + return + } + + // key of interest + key := keys[0] + switch key { + case "*": // wildcard - scan all values + switch m.(type) { + case map[string]interface{}: + for k, v := range m.(map[string]interface{}) { + if string(k[:1]) == "-" && !getAttrs { // skip attributes? + continue + } + valuesFromKeyPath(ret, v, keys[1:], getAttrs) + } + case []interface{}: + for _, v := range m.([]interface{}) { + switch v.(type) { + // flatten out a list of maps - keys are processed + case map[string]interface{}: + for kk, vv := range v.(map[string]interface{}) { + if string(kk[:1]) == "-" && !getAttrs { // skip attributes? + continue + } + valuesFromKeyPath(ret, vv, keys[1:], getAttrs) + } + default: + valuesFromKeyPath(ret, v, keys[1:], getAttrs) + } + } + } + default: // key - must be map[string]interface{} + switch m.(type) { + case map[string]interface{}: + if v, ok := m.(map[string]interface{})[key]; ok { + valuesFromKeyPath(ret, v, keys[1:], getAttrs) + } + case []interface{}: // may be buried in list + for _, v := range m.([]interface{}) { + switch v.(type) { + case map[string]interface{}: + if vv, ok := v.(map[string]interface{})[key]; ok { + valuesFromKeyPath(ret, vv, keys[1:], getAttrs) + } + } + } + } + } +} diff --git a/vendor/github.com/beego/x2j/x2m_bulk.go b/vendor/github.com/beego/x2j/x2m_bulk.go new file mode 100644 index 00000000..4e6c5afb --- /dev/null +++ b/vendor/github.com/beego/x2j/x2m_bulk.go @@ -0,0 +1,201 @@ +// Copyright 2012-2013 Charles Banning. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file + +// x2m_bulk.go: Process files with multiple XML messages. + +package x2j + +import ( + "bytes" + "encoding/xml" + "errors" + "io" + "os" + "regexp" + "sync" +) + +// XmlMsgsFromFile() +// 'fname' is name of file +// 'phandler' is the map processing handler. Return of 'false' stops further processing. +// 'ehandler' is the parsing error handler. Return of 'false' stops further processing and returns error. +// Note: phandler() and ehandler() calls are blocking, so reading and processing of messages is serialized. +// This means that you can stop reading the file on error or after processing a particular message. +// To have reading and handling run concurrently, pass arguments to a go routine in handler and return true. 
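+// Example (illustrative; the handler bodies and file name are hypothetical):
+//
+//	phandler := func(m map[string]interface{}) bool {
+//		fmt.Println(m) // process one message
+//		return true    // keep reading
+//	}
+//	ehandler := func(err error) bool {
+//		log.Println(err) // log and skip the bad message
+//		return true
+//	}
+//	err := XmlMsgsFromFile("msgs.xml", phandler, ehandler)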
+func XmlMsgsFromFile(fname string, phandler func(map[string]interface{})(bool), ehandler func(error)(bool), recast ...bool) error { + var r bool + if len(recast) == 1 { + r = recast[0] + } + fi, fierr := os.Stat(fname) + if fierr != nil { + return fierr + } + fh, fherr := os.Open(fname) + if fherr != nil { + return fherr + } + defer fh.Close() + buf := make([]byte,fi.Size()) + _, rerr := fh.Read(buf) + if rerr != nil { + return rerr + } + doc := string(buf) + + // xml.Decoder doesn't properly handle whitespace in some doc + // see songTextString.xml test case ... + reg,_ := regexp.Compile("[ \t\n\r]*<") + doc = reg.ReplaceAllString(doc,"<") + b := bytes.NewBufferString(doc) + + for { + m, merr := XmlBufferToMap(b,r) + if merr != nil && merr != io.EOF { + if ok := ehandler(merr); !ok { + // caused reader termination + return merr + } + } + if m != nil { + if ok := phandler(m); !ok { + break + } + } + if merr == io.EOF { + break + } + } + return nil +} + +// XmlBufferToMap - process XML message from a bytes.Buffer +// 'b' is the buffer +// Optional argument 'recast' coerces map values to float64 or bool where possible. +func XmlBufferToMap(b *bytes.Buffer,recast ...bool) (map[string]interface{},error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + + n,err := XmlBufferToTree(b) + if err != nil { + return nil, err + } + + m := make(map[string]interface{}) + m[n.key] = n.treeToMap(r) + + return m,nil +} + +// BufferToTree - derived from DocToTree() +func XmlBufferToTree(b *bytes.Buffer) (*Node, error) { + p := xml.NewDecoder(b) + p.CharsetReader = X2jCharsetReader + n, berr := xmlToTree("",nil,p) + if berr != nil { + return nil, berr + } + + return n,nil +} + +// XmlBuffer - create XML decoder buffer for a string from anywhere, not necessarily a file. +type XmlBuffer struct { + cnt uint64 + str *string + buf *bytes.Buffer +} +var mtx sync.Mutex +var cnt uint64 +var activeXmlBufs = make(map[uint64]*XmlBuffer) + +// NewXmlBuffer() - creates a bytes.Buffer from a string with multiple messages +// Use Close() function to release the buffer for garbage collection. +func NewXmlBuffer(s string) *XmlBuffer { + // xml.Decoder doesn't properly handle whitespace in some doc + // see songTextString.xml test case ... + reg,_ := regexp.Compile("[ \t\n\r]*<") + s = reg.ReplaceAllString(s,"<") + b := bytes.NewBufferString(s) + buf := new(XmlBuffer) + buf.str = &s + buf.buf = b + mtx.Lock() + defer mtx.Unlock() + buf.cnt = cnt ; cnt++ + activeXmlBufs[buf.cnt] = buf + return buf +} + +// BytesNewXmlBuffer() - creates a bytes.Buffer from b with possibly multiple messages +// Use Close() function to release the buffer for garbage collection. +func BytesNewXmlBuffer(b []byte) *XmlBuffer { + bb := bytes.NewBuffer(b) + buf := new(XmlBuffer) + buf.buf = bb + mtx.Lock() + defer mtx.Unlock() + buf.cnt = cnt ; cnt++ + activeXmlBufs[buf.cnt] = buf + return buf +} + +// Close() - release the buffer address for garbage collection +func (buf *XmlBuffer)Close() { + mtx.Lock() + defer mtx.Unlock() + delete(activeXmlBufs,buf.cnt) +} + +// NextMap() - retrieve next XML message in buffer as a map[string]interface{} value. +// The optional argument 'recast' will try and coerce values to float64 or bool as appropriate. 
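+// Example (illustrative; multiMsgString is a hypothetical string of concatenated messages):
+//
+//	buf := NewXmlBuffer(multiMsgString)
+//	defer buf.Close()
+//	for {
+//		m, err := buf.NextMap()
+//		if err == io.EOF {
+//			break
+//		}
+//		// ... use m ...
+//	}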
+func (buf *XmlBuffer)NextMap(recast ...bool) (map[string]interface{}, error) { + var r bool + if len(recast) == 1 { + r = recast[0] + } + if _, ok := activeXmlBufs[buf.cnt]; !ok { + return nil, errors.New("Buffer is not active.") + } + return XmlBufferToMap(buf.buf,r) +} + + +// ============================= io.Reader version for stream processing ====================== + +// XmlMsgsFromReader() - io.Reader version of XmlMsgsFromFile +// 'rdr' is an io.Reader for an XML message (stream) +// 'phandler' is the map processing handler. Return of 'false' stops further processing. +// 'ehandler' is the parsing error handler. Return of 'false' stops further processing and returns error. +// Note: phandler() and ehandler() calls are blocking, so reading and processing of messages is serialized. +// This means that you can stop reading the file on error or after processing a particular message. +// To have reading and handling run concurrently, pass arguments to a go routine in handler and return true. +func XmlMsgsFromReader(rdr io.Reader, phandler func(map[string]interface{})(bool), ehandler func(error)(bool), recast ...bool) error { + var r bool + if len(recast) == 1 { + r = recast[0] + } + + for { + m, merr := ToMap(rdr,r) + if merr != nil && merr != io.EOF { + if ok := ehandler(merr); !ok { + // caused reader termination + return merr + } + } + if m != nil { + if ok := phandler(m); !ok { + break + } + } + if merr == io.EOF { + break + } + } + return nil +} + diff --git a/vendor/github.com/beego/x2j/x2m_bulk.xml b/vendor/github.com/beego/x2j/x2m_bulk.xml new file mode 100644 index 00000000..0db72b73 --- /dev/null +++ b/vendor/github.com/beego/x2j/x2m_bulk.xml @@ -0,0 +1,67 @@ + + help me! + + + + Henry was a renegade + Didn't like to play it safe + One component at a time + There's got to be a better way + Oh, people came from miles around + Searching for a steady job + Welcome to the Motor Town + Booming like an atom bomb + + + Oh, Henry was the end of the story + Then everything went wrong + And we'll return it to its former glory + But it just takes so long + + + + It's going to take a long time + It's going to take it, but we'll make it one day + It's going to take a long time + It's going to take it, but we'll make it one day + + + + + help me! + + + + Henry was a renegade + Didn't like to play it safe + One component at a time + There's got to be a better way + Oh, people came from miles around + Searching for a steady job + Welcome to the Motor Town + Booming like an atom bomb + + + + + + help me! + + + It's going to take a long time + It's going to take it, but we'll make it one day + It's going to take a long time + It's going to take it, but we'll make it one day + + + + + help me! 
+ + + It's going to take a long time + It's going to take it, but we'll make it one day + It's going to take a long time + It's going to take it, but we'll make it one day + + diff --git a/vendor/github.com/belogik/goes/.gitignore b/vendor/github.com/belogik/goes/.gitignore new file mode 100644 index 00000000..6788dd3a --- /dev/null +++ b/vendor/github.com/belogik/goes/.gitignore @@ -0,0 +1,2 @@ +*.test +*.swp diff --git a/vendor/github.com/belogik/goes/.travis.yml b/vendor/github.com/belogik/goes/.travis.yml new file mode 100644 index 00000000..4afb3109 --- /dev/null +++ b/vendor/github.com/belogik/goes/.travis.yml @@ -0,0 +1,35 @@ +language: go + +go: + - 1.1 + - 1.2 + - 1.3 + - 1.4.2 + - 1.4.3 + - 1.5 + - 1.5.1 + +env: + matrix: + - ES_VERSION=1.0.3 GROOVY_VER=2.0.0 + - ES_VERSION=1.1.2 GROOVY_VER=2.0.0 + - ES_VERSION=1.2.1 GROOVY_VER=2.2.0 + - ES_VERSION=1.3.4 + - ES_VERSION=1.4.4 + - ES_VERSION=1.5.2 + - ES_VERSION=1.6.0 + - ES_VERSION=1.7.0 + +before_script: + - mkdir ${HOME}/elasticsearch + - wget https://download.elastic.co/elasticsearch/elasticsearch/elasticsearch-${ES_VERSION}.tar.gz + - tar -xzf elasticsearch-${ES_VERSION}.tar.gz -C ${HOME}/elasticsearch + - "echo 'script.groovy.sandbox.enabled: true' >> ${HOME}/elasticsearch/elasticsearch-${ES_VERSION}/config/elasticsearch.yml" + - 'if [[ "${ES_VERSION}" < "1.3" ]]; then ${HOME}/elasticsearch/elasticsearch-${ES_VERSION}/bin/plugin --install elasticsearch/elasticsearch-lang-groovy/${GROOVY_VER}; fi' + - ${HOME}/elasticsearch/elasticsearch-${ES_VERSION}/bin/elasticsearch >/dev/null & + +install: + - go get gopkg.in/check.v1 + +script: + - make test diff --git a/vendor/github.com/belogik/goes/LICENSE b/vendor/github.com/belogik/goes/LICENSE new file mode 100644 index 00000000..cc3fe42e --- /dev/null +++ b/vendor/github.com/belogik/goes/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013 Belogik. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Belogik nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/vendor/github.com/belogik/goes/Makefile b/vendor/github.com/belogik/goes/Makefile
new file mode 100644
index 00000000..bfcb4139
--- /dev/null
+++ b/vendor/github.com/belogik/goes/Makefile
@@ -0,0 +1,10 @@
+help:
+	@echo "Available targets:"
+	@echo "- test: run tests"
+	@echo "- installdependencies: installs dependencies declared in dependencies.txt"
+
+installdependencies:
+	cat dependencies.txt | xargs go get
+
+test: installdependencies
+	go test -i && go test
diff --git a/vendor/github.com/belogik/goes/README b/vendor/github.com/belogik/goes/README
new file mode 100644
index 00000000..fc997ba9
--- /dev/null
+++ b/vendor/github.com/belogik/goes/README
@@ -0,0 +1,7 @@
+There is a new maintainer for this library.
+
+Please go here: https://github.com/OwnLocal/goes
+
+!!!!! By using this repo you are running on thin ice !!!!!!
+
+https://github.com/belogik/goes/issues/40 might be of interest to you.
diff --git a/vendor/github.com/belogik/goes/TODO b/vendor/github.com/belogik/goes/TODO
new file mode 100644
index 00000000..7431ef89
--- /dev/null
+++ b/vendor/github.com/belogik/goes/TODO
@@ -0,0 +1 @@
+- Add Gzip support to bulk data to save bandwidth
diff --git a/vendor/github.com/belogik/goes/dependencies.txt b/vendor/github.com/belogik/goes/dependencies.txt
new file mode 100644
index 00000000..5e35e095
--- /dev/null
+++ b/vendor/github.com/belogik/goes/dependencies.txt
@@ -0,0 +1 @@
+launchpad.net/gocheck
diff --git a/vendor/github.com/belogik/goes/goes.go b/vendor/github.com/belogik/goes/goes.go
new file mode 100644
index 00000000..107ce6ab
--- /dev/null
+++ b/vendor/github.com/belogik/goes/goes.go
@@ -0,0 +1,603 @@
+// Copyright 2013 Belogik. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package goes provides an API to access Elasticsearch.
+package goes
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"reflect"
+	"strconv"
+	"strings"
+)
+
+const (
+	BULK_COMMAND_INDEX  = "index"
+	BULK_COMMAND_DELETE = "delete"
+)
+
+func (err *SearchError) Error() string {
+	return fmt.Sprintf("[%d] %s", err.StatusCode, err.Msg)
+}
+
+// NewConnection initiates a new Connection to an elasticsearch server
+//
+// This function is pretty useless for now but might be useful in the near future
+// if we need more features like connection pooling or load balancing.
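+// Example (the host, port, and index name are illustrative):
+//
+//	conn := goes.NewConnection("localhost", "9200")
+//	query := map[string]interface{}{
+//		"query": map[string]interface{}{
+//			"match_all": map[string]interface{}{},
+//		},
+//	}
+//	resp, err := conn.Search(query, []string{"myindex"}, nil, nil)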
+func NewConnection(host string, port string) *Connection {
+	return &Connection{host, port, http.DefaultClient}
+}
+
+// WithClient sets the http.Client used for requests and returns the Connection
+func (c *Connection) WithClient(cl *http.Client) *Connection {
+	c.Client = cl
+	return c
+}
+
+// CreateIndex creates a new index represented by a name and a mapping
+func (c *Connection) CreateIndex(name string, mapping interface{}) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		Query:     mapping,
+		IndexList: []string{name},
+		method:    "PUT",
+	}
+
+	return r.Run()
+}
+
+// DeleteIndex deletes an index represented by a name
+func (c *Connection) DeleteIndex(name string) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		IndexList: []string{name},
+		method:    "DELETE",
+	}
+
+	return r.Run()
+}
+
+// RefreshIndex refreshes an index represented by a name
+func (c *Connection) RefreshIndex(name string) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		IndexList: []string{name},
+		method:    "POST",
+		api:       "_refresh",
+	}
+
+	return r.Run()
+}
+
+// UpdateIndexSettings updates settings for an existing index represented by a name and settings,
+// as described here: https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-update-settings.html
+func (c *Connection) UpdateIndexSettings(name string, settings interface{}) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		Query:     settings,
+		IndexList: []string{name},
+		method:    "PUT",
+		api:       "_settings",
+	}
+
+	return r.Run()
+}
+
+// Optimize optimizes an index represented by a name; extra args are also allowed, please check:
+// http://www.elasticsearch.org/guide/en/elasticsearch/reference/current/indices-optimize.html#indices-optimize
+func (c *Connection) Optimize(indexList []string, extraArgs url.Values) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		IndexList: indexList,
+		ExtraArgs: extraArgs,
+		method:    "POST",
+		api:       "_optimize",
+	}
+
+	return r.Run()
+}
+
+// Stats fetches statistics (_stats) for the current elasticsearch server
+func (c *Connection) Stats(indexList []string, extraArgs url.Values) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		IndexList: indexList,
+		ExtraArgs: extraArgs,
+		method:    "GET",
+		api:       "_stats",
+	}
+
+	return r.Run()
+}
+
+// IndexStatus fetches the status (_status) for the indices defined in
+// indexList. Use _all in indexList to get stats for all indices
+func (c *Connection) IndexStatus(indexList []string) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		IndexList: indexList,
+		method:    "GET",
+		api:       "_status",
+	}
+
+	return r.Run()
+}
+
+// BulkSend adds multiple documents in bulk mode
+func (c *Connection) BulkSend(documents []Document) (*Response, error) {
+	// We do not generate a traditional JSON here (often a one liner).
+	// Elasticsearch expects one line of JSON per line (EOL = \n)
+	// plus an extra \n at the very end of the document.
+	//
+	// More information about the Bulk JSON format for Elasticsearch:
+	//
+	// - http://www.elasticsearch.org/guide/reference/api/bulk.html
+	//
+	// This is quite annoying for us as we cannot use the simple JSON
+	// Marshaler available in Run().
+	//
+	// We have to generate this special JSON by ourselves, which leads to
+	// the code below.
+	//
+	// I know it is unreadable; I must find an elegant way to fix this.
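+	// For illustration (index, type, ids, and fields are hypothetical), a
+	// two-document bulk body looks like:
+	//
+	//	{"index":{"_index":"myindex","_type":"mytype","_id":"1"}}
+	//	{"field1":"value1"}
+	//	{"delete":{"_index":"myindex","_type":"mytype","_id":"2"}}
+	//
+	// followed by the mandatory trailing newline.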
+ + // len(documents) * 2 : action + optional_sources + // + 1 : room for the trailing \n + bulkData := make([][]byte, len(documents)*2+1) + i := 0 + + for _, doc := range documents { + action, err := json.Marshal(map[string]interface{}{ + doc.BulkCommand: map[string]interface{}{ + "_index": doc.Index, + "_type": doc.Type, + "_id": doc.Id, + }, + }) + + if err != nil { + return &Response{}, err + } + + bulkData[i] = action + i++ + + if doc.Fields != nil { + if docFields, ok := doc.Fields.(map[string]interface{}); ok { + if len(docFields) == 0 { + continue + } + } else { + typeOfFields := reflect.TypeOf(doc.Fields) + if typeOfFields.Kind() == reflect.Ptr { + typeOfFields = typeOfFields.Elem() + } + if typeOfFields.Kind() != reflect.Struct { + return &Response{}, fmt.Errorf("Document fields not in struct or map[string]interface{} format") + } + if typeOfFields.NumField() == 0 { + continue + } + } + + sources, err := json.Marshal(doc.Fields) + if err != nil { + return &Response{}, err + } + + bulkData[i] = sources + i++ + } + } + + // forces an extra trailing \n absolutely necessary for elasticsearch + bulkData[len(bulkData)-1] = []byte(nil) + + r := Request{ + Conn: c, + method: "POST", + api: "_bulk", + bulkData: bytes.Join(bulkData, []byte("\n")), + } + + return r.Run() +} + +// Search executes a search query against an index +func (c *Connection) Search(query interface{}, indexList []string, typeList []string, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + Query: query, + IndexList: indexList, + TypeList: typeList, + method: "POST", + api: "_search", + ExtraArgs: extraArgs, + } + + return r.Run() +} + +// Count executes a count query against an index, use the Count field in the response for the result +func (c *Connection) Count(query interface{}, indexList []string, typeList []string, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + Query: query, + IndexList: indexList, + TypeList: typeList, + method: "POST", + api: "_count", + ExtraArgs: extraArgs, + } + + return r.Run() +} + +//Query runs a query against an index using the provided http method. +//This method can be used to execute a delete by query, just pass in "DELETE" +//for the HTTP method. 
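+//Example (illustrative): a delete-by-query.
+//
+//	q := map[string]interface{}{
+//		"query": map[string]interface{}{
+//			"term": map[string]interface{}{"user": "kimchy"},
+//		},
+//	}
+//	resp, err := conn.Query(q, []string{"myindex"}, nil, "DELETE", url.Values{})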
+func (c *Connection) Query(query interface{}, indexList []string, typeList []string, httpMethod string, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + Query: query, + IndexList: indexList, + TypeList: typeList, + method: httpMethod, + api: "_query", + ExtraArgs: extraArgs, + } + + return r.Run() +} + +// Scan starts scroll over an index +func (c *Connection) Scan(query interface{}, indexList []string, typeList []string, timeout string, size int) (*Response, error) { + v := url.Values{} + v.Add("search_type", "scan") + v.Add("scroll", timeout) + v.Add("size", strconv.Itoa(size)) + + r := Request{ + Conn: c, + Query: query, + IndexList: indexList, + TypeList: typeList, + method: "POST", + api: "_search", + ExtraArgs: v, + } + + return r.Run() +} + +// Scroll fetches data by scroll id +func (c *Connection) Scroll(scrollId string, timeout string) (*Response, error) { + v := url.Values{} + v.Add("scroll", timeout) + + r := Request{ + Conn: c, + method: "POST", + api: "_search/scroll", + ExtraArgs: v, + Body: []byte(scrollId), + } + + return r.Run() +} + +// Get a typed document by its id +func (c *Connection) Get(index string, documentType string, id string, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + IndexList: []string{index}, + method: "GET", + api: documentType + "/" + id, + ExtraArgs: extraArgs, + } + + return r.Run() +} + +// Index indexes a Document +// The extraArgs is a list of url.Values that you can send to elasticsearch as +// URL arguments, for example, to control routing, ttl, version, op_type, etc. +func (c *Connection) Index(d Document, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + Query: d.Fields, + IndexList: []string{d.Index.(string)}, + TypeList: []string{d.Type}, + ExtraArgs: extraArgs, + method: "POST", + } + + if d.Id != nil { + r.method = "PUT" + r.id = d.Id.(string) + } + + return r.Run() +} + +// Delete deletes a Document d +// The extraArgs is a list of url.Values that you can send to elasticsearch as +// URL arguments, for example, to control routing. +func (c *Connection) Delete(d Document, extraArgs url.Values) (*Response, error) { + r := Request{ + Conn: c, + IndexList: []string{d.Index.(string)}, + TypeList: []string{d.Type}, + ExtraArgs: extraArgs, + method: "DELETE", + id: d.Id.(string), + } + + return r.Run() +} + +// Run executes an elasticsearch Request. 
It converts data to JSON, sends the
+// request and returns the Response obtained
+func (req *Request) Run() (*Response, error) {
+	body, statusCode, err := req.run()
+	esResp := &Response{Status: statusCode}
+
+	if err != nil {
+		return esResp, err
+	}
+
+	if req.method != "HEAD" {
+		err = json.Unmarshal(body, &esResp)
+		if err != nil {
+			return esResp, err
+		}
+		err = json.Unmarshal(body, &esResp.Raw)
+		if err != nil {
+			return esResp, err
+		}
+	}
+
+	if req.api == "_bulk" && esResp.Errors {
+		for _, item := range esResp.Items {
+			for _, i := range item {
+				if i.Error != "" {
+					return esResp, &SearchError{i.Error, i.Status}
+				}
+			}
+		}
+		return esResp, &SearchError{Msg: "Unknown error while bulk indexing"}
+	}
+
+	if esResp.Error != "" {
+		return esResp, &SearchError{esResp.Error, esResp.Status}
+	}
+
+	return esResp, nil
+}
+
+func (req *Request) run() ([]byte, uint64, error) {
+	postData := []byte{}
+
+	// XXX : refactor this
+	if len(req.Body) > 0 {
+		postData = req.Body
+	} else if req.api == "_bulk" {
+		postData = req.bulkData
+	} else {
+		b, err := json.Marshal(req.Query)
+		if err != nil {
+			return nil, 0, err
+		}
+		postData = b
+	}
+
+	reader := bytes.NewReader(postData)
+
+	newReq, err := http.NewRequest(req.method, req.Url(), reader)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	if req.method == "POST" || req.method == "PUT" {
+		newReq.Header.Set("Content-Type", "application/json")
+	}
+
+	resp, err := req.Conn.Client.Do(newReq)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	defer resp.Body.Close()
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, uint64(resp.StatusCode), err
+	}
+
+	if resp.StatusCode > 201 && resp.StatusCode < 400 {
+		return nil, uint64(resp.StatusCode), errors.New(string(body))
+	}
+
+	return body, uint64(resp.StatusCode), nil
+}
+
+// Url builds the URL string for a Request
+func (r *Request) Url() string {
+	path := "/" + strings.Join(r.IndexList, ",")
+
+	if len(r.TypeList) > 0 {
+		path += "/" + strings.Join(r.TypeList, ",")
+	}
+
+	// XXX : for indexing documents using the normal (non bulk) API
+	if len(r.id) > 0 {
+		path += "/" + r.id
+	}
+
+	path += "/" + r.api
+
+	u := url.URL{
+		Scheme:   "http",
+		Host:     fmt.Sprintf("%s:%s", r.Conn.Host, r.Conn.Port),
+		Path:     path,
+		RawQuery: r.ExtraArgs.Encode(),
+	}
+
+	return u.String()
+}
+
+// Buckets returns the list of buckets in an aggregation
+func (a Aggregation) Buckets() []Bucket {
+	result := []Bucket{}
+	if buckets, ok := a["buckets"]; ok {
+		for _, bucket := range buckets.([]interface{}) {
+			result = append(result, bucket.(map[string]interface{}))
+		}
+	}
+
+	return result
+}
+
+// Key returns the key for an aggregation bucket
+func (b Bucket) Key() interface{} {
+	return b["key"]
+}
+
+// DocCount returns the count of documents in this bucket
+func (b Bucket) DocCount() uint64 {
+	return uint64(b["doc_count"].(float64))
+}
+
+// Aggregation returns an aggregation by name from the bucket
+func (b Bucket) Aggregation(name string) Aggregation {
+	if agg, ok := b[name]; ok {
+		return agg.(map[string]interface{})
+	}
+
+	return Aggregation{}
+}
+
+// PutMapping registers a specific mapping for one or more types in one or more indexes
+func (c *Connection) PutMapping(typeName string, mapping interface{}, indexes []string) (*Response, error) {
+
+	r := Request{
+		Conn:      c,
+		Query:     mapping,
+		IndexList: indexes,
+		method:    "PUT",
+		api:       "_mappings/" + typeName,
+	}
+
+	return r.Run()
+}
+
+// GetMapping returns the mapping for the given types in the given indexes
+func (c *Connection) GetMapping(types []string, indexes []string) (*Response, error) {
+
+	r := Request{
+		Conn:      c,
		IndexList: indexes,
+		method:    "GET",
+		api:       "_mapping/" + strings.Join(types, ","),
+	}
+
+	return r.Run()
+}
+
+// IndicesExist checks whether an index (or indices) exists on the server
+func (c *Connection) IndicesExist(indexes []string) (bool, error) {
+
+	r := Request{
+		Conn:      c,
+		IndexList: indexes,
+		method:    "HEAD",
+	}
+
+	resp, err := r.Run()
+
+	return resp.Status == 200, err
+}
+
+// Update updates the document d using the given update query
+func (c *Connection) Update(d Document, query interface{}, extraArgs url.Values) (*Response, error) {
+	r := Request{
+		Conn:      c,
+		Query:     query,
+		IndexList: []string{d.Index.(string)},
+		TypeList:  []string{d.Type},
+		ExtraArgs: extraArgs,
+		method:    "POST",
+		api:       "_update",
+	}
+
+	if d.Id != nil {
+		r.id = d.Id.(string)
+	}
+
+	return r.Run()
+}
+
+// DeleteMapping deletes a mapping along with all data in the type
+func (c *Connection) DeleteMapping(typeName string, indexes []string) (*Response, error) {
+
+	r := Request{
+		Conn:      c,
+		IndexList: indexes,
+		method:    "DELETE",
+		api:       "_mappings/" + typeName,
+	}
+
+	return r.Run()
+}
+
+func (c *Connection) modifyAlias(action string, alias string, indexes []string) (*Response, error) {
+	command := map[string]interface{}{
+		// start with length 0: a non-zero length would leave a spurious nil
+		// action in the marshaled list
+		"actions": make([]map[string]interface{}, 0, len(indexes)),
+	}
+
+	for _, index := range indexes {
+		command["actions"] = append(command["actions"].([]map[string]interface{}), map[string]interface{}{
+			action: map[string]interface{}{
+				"index": index,
+				"alias": alias,
+			},
+		})
+	}
+
+	r := Request{
+		Conn:   c,
+		Query:  command,
+		method: "POST",
+		api:    "_aliases",
+	}
+
+	return r.Run()
+}
+
+// AddAlias creates an alias to one or more indexes
+func (c *Connection) AddAlias(alias string, indexes []string) (*Response, error) {
+	return c.modifyAlias("add", alias, indexes)
+}
+
+// RemoveAlias removes an alias to one or more indexes
+func (c *Connection) RemoveAlias(alias string, indexes []string) (*Response, error) {
+	return c.modifyAlias("remove", alias, indexes)
+}
+
+// AliasExists checks whether an alias is defined on the server
+func (c *Connection) AliasExists(alias string) (bool, error) {
+
+	r := Request{
+		Conn:   c,
+		method: "HEAD",
+		api:    "_alias/" + alias,
+	}
+
+	resp, err := r.Run()
+
+	return resp.Status == 200, err
+}
diff --git a/vendor/github.com/belogik/goes/structs.go b/vendor/github.com/belogik/goes/structs.go
new file mode 100644
index 00000000..c7264613
--- /dev/null
+++ b/vendor/github.com/belogik/goes/structs.go
@@ -0,0 +1,184 @@
+// Copyright 2013 Belogik. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package goes
+
+import (
+	"net/http"
+	"net/url"
+)
+
+// Represents a Connection object to elasticsearch
+type Connection struct {
+	// The host to connect to
+	Host string
+
+	// The port to use
+	Port string
+
+	// Client is the http client used to make requests, allowing setting things
+	// such as timeouts etc
+	Client *http.Client
+}
+
+// Represents a Request to elasticsearch
+type Request struct {
+	// Which connection will be used
+	Conn *Connection
+
+	// A search query
+	Query interface{}
+
+	// Which index to search into
+	IndexList []string
+
+	// Which type to search into
+	TypeList []string
+
+	// HTTP Method to use (GET, POST ...)
+	method string
+
+	// Which api keyword (_search, _bulk, etc) to use
+	api string
+
+	// Bulk data
+	bulkData []byte
+
+	// Request body
+	Body []byte
+
+	// A list of extra URL arguments
+	ExtraArgs url.Values
+
+	// Used for the id field when indexing a document
+	id string
+}
+
+// Represents a Response from elasticsearch
+type Response struct {
+	Acknowledged bool
+	Error        string
+	Errors       bool
+	Status       uint64
+	Took         uint64
+	TimedOut     bool  `json:"timed_out"`
+	Shards       Shard `json:"_shards"`
+	Hits         Hits
+	Index        string `json:"_index"`
+	Id           string `json:"_id"`
+	Type         string `json:"_type"`
+	Version      int    `json:"_version"`
+	Found        bool
+	Count        int
+
+	// Used by the _stats API
+	All All `json:"_all"`
+
+	// Used by the _bulk API
+	Items []map[string]Item `json:"items,omitempty"`
+
+	// Used by the GET API
+	Source map[string]interface{} `json:"_source"`
+	Fields map[string]interface{} `json:"fields"`
+
+	// Used by the _status API
+	Indices map[string]IndexStatus
+
+	// Scroll id for iteration
+	ScrollId string `json:"_scroll_id"`
+
+	Aggregations map[string]Aggregation `json:"aggregations,omitempty"`
+
+	Raw map[string]interface{}
+}
+
+// Represents an aggregation from a response
+type Aggregation map[string]interface{}
+
+// Represents a bucket for an aggregation
+type Bucket map[string]interface{}
+
+// Represents a document to send to elasticsearch
+type Document struct {
+	// XXX : interface as we can support nil values
+	Index       interface{}
+	Type        string
+	Id          interface{}
+	BulkCommand string
+	Fields      interface{}
+}
+
+// Represents the "items" field in a _bulk response
+type Item struct {
+	Type    string `json:"_type"`
+	Id      string `json:"_id"`
+	Index   string `json:"_index"`
+	Version int    `json:"_version"`
+	Error   string `json:"error"`
+	Status  uint64 `json:"status"`
+}
+
+// Represents the "_all" field when calling the _stats API
+// This is minimal, but it is all we need for now
+type All struct {
+	Indices   map[string]StatIndex   `json:"indices"`
+	Primaries map[string]StatPrimary `json:"primaries"`
+}
+
+type StatIndex struct {
+	Primaries map[string]StatPrimary `json:"primaries"`
+}
+
+type StatPrimary struct {
+	// primary/docs:
+	Count   int
+	Deleted int
+}
+
+// Represents the "shard" struct as returned by elasticsearch
+type Shard struct {
+	Total      uint64
+	Successful uint64
+	Failed     uint64
+}
+
+// Represents a hit returned by a search
+type Hit struct {
+	Index     string                 `json:"_index"`
+	Type      string                 `json:"_type"`
+	Id        string                 `json:"_id"`
+	Score     float64                `json:"_score"`
+	Source    map[string]interface{} `json:"_source"`
+	Highlight map[string]interface{} `json:"highlight"`
+	Fields    map[string]interface{} `json:"fields"`
+}
+
+// Represents the hits structure as returned by elasticsearch
+type Hits struct {
+	Total uint64
+	// max_score may contain the "null" value
+	MaxScore interface{} `json:"max_score"`
+	Hits     []Hit
+}
+
+type SearchError struct {
+	Msg        string
+	StatusCode uint64
+}
+
+// Represents the status for a given index for the _status command
+type IndexStatus struct {
+	// XXX : problem, int will be marshaled to a float64 which seems logical
+	// XXX : is it better to use strings even for int values or to keep
+	// XXX : interfaces and deal with float64 ?
+ Index map[string]interface{} + + Translog map[string]uint64 + Docs map[string]uint64 + Merges map[string]interface{} + Refresh map[string]interface{} + Flush map[string]interface{} + + // TODO: add shards support later, we do not need it for the moment +} diff --git a/vendor/github.com/bradfitz/gomemcache/LICENSE b/vendor/github.com/bradfitz/gomemcache/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/vendor/github.com/bradfitz/gomemcache/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/vendor/github.com/bradfitz/gomemcache/memcache/memcache.go b/vendor/github.com/bradfitz/gomemcache/memcache/memcache.go
new file mode 100644
index 00000000..25e88ca2
--- /dev/null
+++ b/vendor/github.com/bradfitz/gomemcache/memcache/memcache.go
@@ -0,0 +1,687 @@
+/*
+Copyright 2011 Google Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package memcache provides a client for the memcached cache server.
+package memcache
+
+import (
+	"bufio"
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+)
+
+// Similar to:
+// https://godoc.org/google.golang.org/appengine/memcache
+
+var (
+	// ErrCacheMiss means that a Get failed because the item wasn't present.
+	ErrCacheMiss = errors.New("memcache: cache miss")
+
+	// ErrCASConflict means that a CompareAndSwap call failed due to the
+	// cached value being modified between the Get and the CompareAndSwap.
+	// If the cached value was simply evicted rather than replaced,
+	// ErrNotStored will be returned instead.
+	ErrCASConflict = errors.New("memcache: compare-and-swap conflict")
+
+	// ErrNotStored means that a conditional write operation (i.e. Add or
+	// CompareAndSwap) failed because the condition was not satisfied.
+	ErrNotStored = errors.New("memcache: item not stored")
+
+	// ErrServerError means that a server error occurred.
+	ErrServerError = errors.New("memcache: server error")
+
+	// ErrNoStats means that no statistics were available.
+	ErrNoStats = errors.New("memcache: no statistics available")
+
+	// ErrMalformedKey is returned when an invalid key is used.
+	// Keys must be at maximum 250 bytes long and not
+	// contain whitespace or control characters.
+	ErrMalformedKey = errors.New("malformed: key is too long or contains invalid characters")
+
+	// ErrNoServers is returned when no servers are configured or available.
+	ErrNoServers = errors.New("memcache: no servers configured or available")
+)
+
+const (
+	// DefaultTimeout is the default socket read/write timeout.
+	DefaultTimeout = 100 * time.Millisecond
+
+	// DefaultMaxIdleConns is the default maximum number of idle connections
+	// kept for any single address.
+	DefaultMaxIdleConns = 2
+)
+
+const buffered = 8 // arbitrary buffered channel size, for readability
+
+// resumableError returns true if err is only a protocol-level cache error.
+// This is used to determine whether a server connection should be reused.
+// If an error occurs, by default we don't reuse the connection, unless it
+// was just a cache error.
+func resumableError(err error) bool {
+	switch err {
+	case ErrCacheMiss, ErrCASConflict, ErrNotStored, ErrMalformedKey:
+		return true
+	}
+	return false
+}
+
+func legalKey(key string) bool {
+	if len(key) > 250 {
+		return false
+	}
+	for i := 0; i < len(key); i++ {
+		if key[i] <= ' ' || key[i] == 0x7f {
+			return false
+		}
+	}
+	return true
+}
+
+var (
+	crlf            = []byte("\r\n")
+	space           = []byte(" ")
+	resultOK        = []byte("OK\r\n")
+	resultStored    = []byte("STORED\r\n")
+	resultNotStored = []byte("NOT_STORED\r\n")
+	resultExists    = []byte("EXISTS\r\n")
+	resultNotFound  = []byte("NOT_FOUND\r\n")
+	resultDeleted   = []byte("DELETED\r\n")
+	resultEnd       = []byte("END\r\n")
+	resultOk        = []byte("OK\r\n")
+	resultTouched   = []byte("TOUCHED\r\n")
+
+	resultClientErrorPrefix = []byte("CLIENT_ERROR ")
+)
+
+// New returns a memcache client using the provided server(s)
+// with equal weight. If a server is listed multiple times,
+// it gets a proportional amount of weight.
+func New(server ...string) *Client {
+	ss := new(ServerList)
+	ss.SetServers(server...)
+	return NewFromSelector(ss)
+}
+
+// NewFromSelector returns a new Client using the provided ServerSelector.
+func NewFromSelector(ss ServerSelector) *Client {
+	return &Client{selector: ss}
+}
+
+// Client is a memcache client.
+// It is safe for unlocked use by multiple concurrent goroutines.
+type Client struct {
+	// Timeout specifies the socket read/write timeout.
+	// If zero, DefaultTimeout is used.
+	Timeout time.Duration
+
+	// MaxIdleConns specifies the maximum number of idle connections that will
+	// be maintained per address. If less than one, DefaultMaxIdleConns will be
+	// used.
+	//
+	// Consider your expected traffic rates and latency carefully. This should
+	// be set to a number higher than your peak parallel requests.
+	MaxIdleConns int
+
+	selector ServerSelector
+
+	lk       sync.Mutex
+	freeconn map[string][]*conn
+}
+
+// Item is an item to be got or stored in a memcached server.
+type Item struct {
+	// Key is the Item's key (250 bytes maximum).
+	Key string
+
+	// Value is the Item's value.
+	Value []byte
+
+	// Flags are server-opaque flags whose semantics are entirely
+	// up to the app.
+	Flags uint32
+
+	// Expiration is the cache expiration time, in seconds: either a relative
+	// time from now (up to 1 month), or an absolute Unix epoch time.
+	// Zero means the Item has no expiration time.
+	Expiration int32
+
+	// Compare and swap ID.
+	casid uint64
+}
+
+// conn is a connection to a server.
+type conn struct {
+	nc   net.Conn
+	rw   *bufio.ReadWriter
+	addr net.Addr
+	c    *Client
+}
+
+// release returns this connection back to the client's free pool
+func (cn *conn) release() {
+	cn.c.putFreeConn(cn.addr, cn)
+}
+
+func (cn *conn) extendDeadline() {
+	cn.nc.SetDeadline(time.Now().Add(cn.c.netTimeout()))
+}
+
+// condRelease releases this connection if the error pointed to by err
+// is nil (not an error) or is only a protocol level error (e.g. a
+// cache miss). The purpose is to not recycle TCP connections that
+// are bad.
+func (cn *conn) condRelease(err *error) { + if *err == nil || resumableError(*err) { + cn.release() + } else { + cn.nc.Close() + } +} + +func (c *Client) putFreeConn(addr net.Addr, cn *conn) { + c.lk.Lock() + defer c.lk.Unlock() + if c.freeconn == nil { + c.freeconn = make(map[string][]*conn) + } + freelist := c.freeconn[addr.String()] + if len(freelist) >= c.maxIdleConns() { + cn.nc.Close() + return + } + c.freeconn[addr.String()] = append(freelist, cn) +} + +func (c *Client) getFreeConn(addr net.Addr) (cn *conn, ok bool) { + c.lk.Lock() + defer c.lk.Unlock() + if c.freeconn == nil { + return nil, false + } + freelist, ok := c.freeconn[addr.String()] + if !ok || len(freelist) == 0 { + return nil, false + } + cn = freelist[len(freelist)-1] + c.freeconn[addr.String()] = freelist[:len(freelist)-1] + return cn, true +} + +func (c *Client) netTimeout() time.Duration { + if c.Timeout != 0 { + return c.Timeout + } + return DefaultTimeout +} + +func (c *Client) maxIdleConns() int { + if c.MaxIdleConns > 0 { + return c.MaxIdleConns + } + return DefaultMaxIdleConns +} + +// ConnectTimeoutError is the error type used when it takes +// too long to connect to the desired host. This level of +// detail can generally be ignored. +type ConnectTimeoutError struct { + Addr net.Addr +} + +func (cte *ConnectTimeoutError) Error() string { + return "memcache: connect timeout to " + cte.Addr.String() +} + +func (c *Client) dial(addr net.Addr) (net.Conn, error) { + type connError struct { + cn net.Conn + err error + } + + nc, err := net.DialTimeout(addr.Network(), addr.String(), c.netTimeout()) + if err == nil { + return nc, nil + } + + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return nil, &ConnectTimeoutError{addr} + } + + return nil, err +} + +func (c *Client) getConn(addr net.Addr) (*conn, error) { + cn, ok := c.getFreeConn(addr) + if ok { + cn.extendDeadline() + return cn, nil + } + nc, err := c.dial(addr) + if err != nil { + return nil, err + } + cn = &conn{ + nc: nc, + addr: addr, + rw: bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)), + c: c, + } + cn.extendDeadline() + return cn, nil +} + +func (c *Client) onItem(item *Item, fn func(*Client, *bufio.ReadWriter, *Item) error) error { + addr, err := c.selector.PickServer(item.Key) + if err != nil { + return err + } + cn, err := c.getConn(addr) + if err != nil { + return err + } + defer cn.condRelease(&err) + if err = fn(c, cn.rw, item); err != nil { + return err + } + return nil +} + +func (c *Client) FlushAll() error { + return c.selector.Each(c.flushAllFromAddr) +} + +// Get gets the item for the given key. ErrCacheMiss is returned for a +// memcache cache miss. The key must be at most 250 bytes in length. +func (c *Client) Get(key string) (item *Item, err error) { + err = c.withKeyAddr(key, func(addr net.Addr) error { + return c.getFromAddr(addr, []string{key}, func(it *Item) { item = it }) + }) + if err == nil && item == nil { + err = ErrCacheMiss + } + return +} + +// Touch updates the expiry for the given key. The seconds parameter is either +// a Unix timestamp or, if seconds is less than 1 month, the number of seconds +// into the future at which time the item will expire. Zero means the item has +// no expiration time. ErrCacheMiss is returned if the key is not in the cache. +// The key must be at most 250 bytes in length. 
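+//
+// A hedged sketch (key and TTL are illustrative):
+//
+//	err := c.Touch("session:42", 300) // extend the item's TTL to 5 minutes from now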
+func (c *Client) Touch(key string, seconds int32) (err error) {
+	return c.withKeyAddr(key, func(addr net.Addr) error {
+		return c.touchFromAddr(addr, []string{key}, seconds)
+	})
+}
+
+func (c *Client) withKeyAddr(key string, fn func(net.Addr) error) (err error) {
+	if !legalKey(key) {
+		return ErrMalformedKey
+	}
+	addr, err := c.selector.PickServer(key)
+	if err != nil {
+		return err
+	}
+	return fn(addr)
+}
+
+func (c *Client) withAddrRw(addr net.Addr, fn func(*bufio.ReadWriter) error) (err error) {
+	cn, err := c.getConn(addr)
+	if err != nil {
+		return err
+	}
+	defer cn.condRelease(&err)
+	return fn(cn.rw)
+}
+
+func (c *Client) withKeyRw(key string, fn func(*bufio.ReadWriter) error) error {
+	return c.withKeyAddr(key, func(addr net.Addr) error {
+		return c.withAddrRw(addr, fn)
+	})
+}
+
+func (c *Client) getFromAddr(addr net.Addr, keys []string, cb func(*Item)) error {
+	return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error {
+		if _, err := fmt.Fprintf(rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
+			return err
+		}
+		if err := rw.Flush(); err != nil {
+			return err
+		}
+		if err := parseGetResponse(rw.Reader, cb); err != nil {
+			return err
+		}
+		return nil
+	})
+}
+
+// flushAllFromAddr sends the flush_all command to the given addr
+func (c *Client) flushAllFromAddr(addr net.Addr) error {
+	return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error {
+		if _, err := fmt.Fprintf(rw, "flush_all\r\n"); err != nil {
+			return err
+		}
+		if err := rw.Flush(); err != nil {
+			return err
+		}
+		line, err := rw.ReadSlice('\n')
+		if err != nil {
+			return err
+		}
+		switch {
+		case bytes.Equal(line, resultOk):
+			break
+		default:
+			return fmt.Errorf("memcache: unexpected response line from flush_all: %q", string(line))
+		}
+		return nil
+	})
+}
+
+func (c *Client) touchFromAddr(addr net.Addr, keys []string, expiration int32) error {
+	return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error {
+		for _, key := range keys {
+			if _, err := fmt.Fprintf(rw, "touch %s %d\r\n", key, expiration); err != nil {
+				return err
+			}
+			if err := rw.Flush(); err != nil {
+				return err
+			}
+			line, err := rw.ReadSlice('\n')
+			if err != nil {
+				return err
+			}
+			switch {
+			case bytes.Equal(line, resultTouched):
+				break
+			case bytes.Equal(line, resultNotFound):
+				return ErrCacheMiss
+			default:
+				return fmt.Errorf("memcache: unexpected response line from touch: %q", string(line))
+			}
+		}
+		return nil
+	})
+}
+
+// GetMulti is a batch version of Get. The returned map from keys to
+// items may have fewer elements than the input slice, due to memcache
+// cache misses. Each key must be at most 250 bytes in length.
+// If no error is returned, the returned map will also be non-nil.
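+//
+// A hedged usage sketch (server address and keys are illustrative):
+//
+//	c := New("127.0.0.1:11211")
+//	items, err := c.GetMulti([]string{"k1", "k2", "k3"})
+//	// items only contains entries for the keys that were found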
+func (c *Client) GetMulti(keys []string) (map[string]*Item, error) {
+	var lk sync.Mutex
+	m := make(map[string]*Item)
+	addItemToMap := func(it *Item) {
+		lk.Lock()
+		defer lk.Unlock()
+		m[it.Key] = it
+	}
+
+	keyMap := make(map[net.Addr][]string)
+	for _, key := range keys {
+		if !legalKey(key) {
+			return nil, ErrMalformedKey
+		}
+		addr, err := c.selector.PickServer(key)
+		if err != nil {
+			return nil, err
+		}
+		keyMap[addr] = append(keyMap[addr], key)
+	}
+
+	ch := make(chan error, buffered)
+	for addr, keys := range keyMap {
+		go func(addr net.Addr, keys []string) {
+			ch <- c.getFromAddr(addr, keys, addItemToMap)
+		}(addr, keys)
+	}
+
+	var err error
+	for range keyMap {
+		if ge := <-ch; ge != nil {
+			err = ge
+		}
+	}
+	return m, err
+}
+
+// parseGetResponse reads a GET response from r and calls cb for each
+// read and allocated Item
+func parseGetResponse(r *bufio.Reader, cb func(*Item)) error {
+	for {
+		line, err := r.ReadSlice('\n')
+		if err != nil {
+			return err
+		}
+		if bytes.Equal(line, resultEnd) {
+			return nil
+		}
+		it := new(Item)
+		size, err := scanGetResponseLine(line, it)
+		if err != nil {
+			return err
+		}
+		it.Value = make([]byte, size+2)
+		_, err = io.ReadFull(r, it.Value)
+		if err != nil {
+			it.Value = nil
+			return err
+		}
+		if !bytes.HasSuffix(it.Value, crlf) {
+			it.Value = nil
+			return fmt.Errorf("memcache: corrupt get result read")
+		}
+		it.Value = it.Value[:size]
+		cb(it)
+	}
+}
+
+// scanGetResponseLine populates it and returns the declared size of the item.
+// It does not read the bytes of the item.
+func scanGetResponseLine(line []byte, it *Item) (size int, err error) {
+	pattern := "VALUE %s %d %d %d\r\n"
+	dest := []interface{}{&it.Key, &it.Flags, &size, &it.casid}
+	if bytes.Count(line, space) == 3 {
+		pattern = "VALUE %s %d %d\r\n"
+		dest = dest[:3]
+	}
+	n, err := fmt.Sscanf(string(line), pattern, dest...)
+	if err != nil || n != len(dest) {
+		return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line)
+	}
+	return size, nil
+}
+
+// Set writes the given item, unconditionally.
+func (c *Client) Set(item *Item) error {
+	return c.onItem(item, (*Client).set)
+}
+
+func (c *Client) set(rw *bufio.ReadWriter, item *Item) error {
+	return c.populateOne(rw, "set", item)
+}
+
+// Add writes the given item, if no value already exists for its
+// key. ErrNotStored is returned if that condition is not met.
+func (c *Client) Add(item *Item) error {
+	return c.onItem(item, (*Client).add)
+}
+
+func (c *Client) add(rw *bufio.ReadWriter, item *Item) error {
+	return c.populateOne(rw, "add", item)
+}
+
+// Replace writes the given item, but only if the server *does*
+// already hold data for this key
+func (c *Client) Replace(item *Item) error {
+	return c.onItem(item, (*Client).replace)
+}
+
+func (c *Client) replace(rw *bufio.ReadWriter, item *Item) error {
+	return c.populateOne(rw, "replace", item)
+}
+
+// CompareAndSwap writes the given item that was previously returned
+// by Get, if the value was neither modified nor evicted between the
+// Get and the CompareAndSwap calls. The item's Key should not change
+// between calls but all other item fields may differ. ErrCASConflict
+// is returned if the value was modified in between the
+// calls. ErrNotStored is returned if the value was evicted in between
+// the calls.
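+//
+// A hedged usage sketch of the typical read-modify-write loop (key and
+// mutation are illustrative):
+//
+//	it, err := c.Get("counter")          // it carries the hidden casid
+//	if err == nil {
+//		it.Value = append(it.Value, '!') // modify the fetched value
+//		err = c.CompareAndSwap(it)       // ErrCASConflict if changed meanwhile
+//	}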
+func (c *Client) CompareAndSwap(item *Item) error {
+	return c.onItem(item, (*Client).cas)
+}
+
+func (c *Client) cas(rw *bufio.ReadWriter, item *Item) error {
+	return c.populateOne(rw, "cas", item)
+}
+
+func (c *Client) populateOne(rw *bufio.ReadWriter, verb string, item *Item) error {
+	if !legalKey(item.Key) {
+		return ErrMalformedKey
+	}
+	var err error
+	if verb == "cas" {
+		_, err = fmt.Fprintf(rw, "%s %s %d %d %d %d\r\n",
+			verb, item.Key, item.Flags, item.Expiration, len(item.Value), item.casid)
+	} else {
+		_, err = fmt.Fprintf(rw, "%s %s %d %d %d\r\n",
+			verb, item.Key, item.Flags, item.Expiration, len(item.Value))
+	}
+	if err != nil {
+		return err
+	}
+	if _, err = rw.Write(item.Value); err != nil {
+		return err
+	}
+	if _, err := rw.Write(crlf); err != nil {
+		return err
+	}
+	if err := rw.Flush(); err != nil {
+		return err
+	}
+	line, err := rw.ReadSlice('\n')
+	if err != nil {
+		return err
+	}
+	switch {
+	case bytes.Equal(line, resultStored):
+		return nil
+	case bytes.Equal(line, resultNotStored):
+		return ErrNotStored
+	case bytes.Equal(line, resultExists):
+		return ErrCASConflict
+	case bytes.Equal(line, resultNotFound):
+		return ErrCacheMiss
+	}
+	return fmt.Errorf("memcache: unexpected response line from %q: %q", verb, string(line))
+}
+
+func writeReadLine(rw *bufio.ReadWriter, format string, args ...interface{}) ([]byte, error) {
+	_, err := fmt.Fprintf(rw, format, args...)
+	if err != nil {
+		return nil, err
+	}
+	if err := rw.Flush(); err != nil {
+		return nil, err
+	}
+	line, err := rw.ReadSlice('\n')
+	return line, err
+}
+
+func writeExpectf(rw *bufio.ReadWriter, expect []byte, format string, args ...interface{}) error {
+	line, err := writeReadLine(rw, format, args...)
+	if err != nil {
+		return err
+	}
+	switch {
+	case bytes.Equal(line, resultOK):
+		return nil
+	case bytes.Equal(line, expect):
+		return nil
+	case bytes.Equal(line, resultNotStored):
+		return ErrNotStored
+	case bytes.Equal(line, resultExists):
+		return ErrCASConflict
+	case bytes.Equal(line, resultNotFound):
+		return ErrCacheMiss
+	}
+	return fmt.Errorf("memcache: unexpected response line: %q", string(line))
+}
+
+// Delete deletes the item with the provided key. The error ErrCacheMiss is
+// returned if the item didn't already exist in the cache.
+func (c *Client) Delete(key string) error {
+	return c.withKeyRw(key, func(rw *bufio.ReadWriter) error {
+		return writeExpectf(rw, resultDeleted, "delete %s\r\n", key)
+	})
+}
+
+// DeleteAll deletes all items in the cache.
+func (c *Client) DeleteAll() error {
+	return c.withKeyRw("", func(rw *bufio.ReadWriter) error {
+		return writeExpectf(rw, resultDeleted, "flush_all\r\n")
+	})
+}
+
+// Increment atomically increments key by delta. The return value is
+// the new value after being incremented or an error. If the value
+// didn't exist in memcached the error is ErrCacheMiss. The value in
+// memcached must be a decimal number, or an error will be returned.
+// On 64-bit overflow, the new value wraps around.
+func (c *Client) Increment(key string, delta uint64) (newValue uint64, err error) {
+	return c.incrDecr("incr", key, delta)
+}
+
+// Decrement atomically decrements key by delta. The return value is
+// the new value after being decremented or an error. If the value
+// didn't exist in memcached the error is ErrCacheMiss. The value in
+// memcached must be a decimal number, or an error will be returned.
+// On underflow, the new value is capped at zero and does not wrap
+// around.
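+//
+// A hedged sketch (key and delta are illustrative):
+//
+//	left, err := c.Decrement("stock:sku-1", 1) // capped at zero on underflow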
+func (c *Client) Decrement(key string, delta uint64) (newValue uint64, err error) { + return c.incrDecr("decr", key, delta) +} + +func (c *Client) incrDecr(verb, key string, delta uint64) (uint64, error) { + var val uint64 + err := c.withKeyRw(key, func(rw *bufio.ReadWriter) error { + line, err := writeReadLine(rw, "%s %s %d\r\n", verb, key, delta) + if err != nil { + return err + } + switch { + case bytes.Equal(line, resultNotFound): + return ErrCacheMiss + case bytes.HasPrefix(line, resultClientErrorPrefix): + errMsg := line[len(resultClientErrorPrefix) : len(line)-2] + return errors.New("memcache: client error: " + string(errMsg)) + } + val, err = strconv.ParseUint(string(line[:len(line)-2]), 10, 64) + if err != nil { + return err + } + return nil + }) + return val, err +} diff --git a/vendor/github.com/bradfitz/gomemcache/memcache/selector.go b/vendor/github.com/bradfitz/gomemcache/memcache/selector.go new file mode 100644 index 00000000..89ad81e0 --- /dev/null +++ b/vendor/github.com/bradfitz/gomemcache/memcache/selector.go @@ -0,0 +1,129 @@ +/* +Copyright 2011 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package memcache + +import ( + "hash/crc32" + "net" + "strings" + "sync" +) + +// ServerSelector is the interface that selects a memcache server +// as a function of the item's key. +// +// All ServerSelector implementations must be safe for concurrent use +// by multiple goroutines. +type ServerSelector interface { + // PickServer returns the server address that a given item + // should be shared onto. + PickServer(key string) (net.Addr, error) + Each(func(net.Addr) error) error +} + +// ServerList is a simple ServerSelector. Its zero value is usable. +type ServerList struct { + mu sync.RWMutex + addrs []net.Addr +} + +// staticAddr caches the Network() and String() values from any net.Addr. +type staticAddr struct { + ntw, str string +} + +func newStaticAddr(a net.Addr) net.Addr { + return &staticAddr{ + ntw: a.Network(), + str: a.String(), + } +} + +func (s *staticAddr) Network() string { return s.ntw } +func (s *staticAddr) String() string { return s.str } + +// SetServers changes a ServerList's set of servers at runtime and is +// safe for concurrent use by multiple goroutines. +// +// Each server is given equal weight. A server is given more weight +// if it's listed multiple times. +// +// SetServers returns an error if any of the server names fail to +// resolve. No attempt is made to connect to the server. If any error +// is returned, no changes are made to the ServerList. 
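+//
+// A hedged sketch of the weighting rule (addresses are illustrative):
+//
+//	ss := new(ServerList)
+//	// listing 10.0.0.1 twice gives it twice the weight of 10.0.0.2
+//	err := ss.SetServers("10.0.0.1:11211", "10.0.0.1:11211", "10.0.0.2:11211")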
+func (ss *ServerList) SetServers(servers ...string) error { + naddr := make([]net.Addr, len(servers)) + for i, server := range servers { + if strings.Contains(server, "/") { + addr, err := net.ResolveUnixAddr("unix", server) + if err != nil { + return err + } + naddr[i] = newStaticAddr(addr) + } else { + tcpaddr, err := net.ResolveTCPAddr("tcp", server) + if err != nil { + return err + } + naddr[i] = newStaticAddr(tcpaddr) + } + } + + ss.mu.Lock() + defer ss.mu.Unlock() + ss.addrs = naddr + return nil +} + +// Each iterates over each server calling the given function +func (ss *ServerList) Each(f func(net.Addr) error) error { + ss.mu.RLock() + defer ss.mu.RUnlock() + for _, a := range ss.addrs { + if err := f(a); nil != err { + return err + } + } + return nil +} + +// keyBufPool returns []byte buffers for use by PickServer's call to +// crc32.ChecksumIEEE to avoid allocations. (but doesn't avoid the +// copies, which at least are bounded in size and small) +var keyBufPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 256) + return &b + }, +} + +func (ss *ServerList) PickServer(key string) (net.Addr, error) { + ss.mu.RLock() + defer ss.mu.RUnlock() + if len(ss.addrs) == 0 { + return nil, ErrNoServers + } + if len(ss.addrs) == 1 { + return ss.addrs[0], nil + } + bufp := keyBufPool.Get().(*[]byte) + n := copy(*bufp, key) + cs := crc32.ChecksumIEEE((*bufp)[:n]) + keyBufPool.Put(bufp) + + return ss.addrs[cs%uint32(len(ss.addrs))], nil +} diff --git a/vendor/github.com/casbin/casbin/.gitignore b/vendor/github.com/casbin/casbin/.gitignore new file mode 100644 index 00000000..72ed6f47 --- /dev/null +++ b/vendor/github.com/casbin/casbin/.gitignore @@ -0,0 +1,27 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +.idea/ +*.iml \ No newline at end of file diff --git a/vendor/github.com/casbin/casbin/.travis.yml b/vendor/github.com/casbin/casbin/.travis.yml new file mode 100644 index 00000000..be7f6341 --- /dev/null +++ b/vendor/github.com/casbin/casbin/.travis.yml @@ -0,0 +1,12 @@ +language: go + +sudo: false + +go: + - tip + +before_install: + - go get github.com/mattn/goveralls + +script: + - $HOME/gopath/bin/goveralls -service=travis-ci \ No newline at end of file diff --git a/vendor/github.com/casbin/casbin/LICENSE b/vendor/github.com/casbin/casbin/LICENSE new file mode 100644 index 00000000..8dada3ed --- /dev/null +++ b/vendor/github.com/casbin/casbin/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/vendor/github.com/casbin/casbin/README.md b/vendor/github.com/casbin/casbin/README.md
new file mode 100644
index 00000000..c4704ae7
--- /dev/null
+++ b/vendor/github.com/casbin/casbin/README.md
@@ -0,0 +1,366 @@
+Casbin
+====
+
+[![Go Report Card](https://goreportcard.com/badge/github.com/casbin/casbin)](https://goreportcard.com/report/github.com/casbin/casbin)
+[![Build Status](https://travis-ci.org/casbin/casbin.svg?branch=master)](https://travis-ci.org/casbin/casbin)
+[![Coverage Status](https://coveralls.io/repos/github/casbin/casbin/badge.svg?branch=master)](https://coveralls.io/github/casbin/casbin?branch=master)
+[![Godoc](https://godoc.org/github.com/casbin/casbin?status.svg)](https://godoc.org/github.com/casbin/casbin)
+[![Release](https://img.shields.io/github/release/casbin/casbin.svg)](https://github.com/casbin/casbin/releases/latest)
+[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/casbin/lobby)
+[![Patreon](https://img.shields.io/badge/patreon-donate-yellow.svg)](http://www.patreon.com/yangluo)
+[![Sourcegraph Badge](https://sourcegraph.com/github.com/casbin/casbin/-/badge.svg)](https://sourcegraph.com/github.com/casbin/casbin?badge)
+
+**News**: Still worried about how to write a correct Casbin policy? The ``Casbin online editor`` is here to help! Try it at: http://casbin.org/editor/
+
+![casbin Logo](casbin-logo.png)
+
+Casbin is a powerful and efficient open-source access control library for Golang projects. It provides support for enforcing authorization based on various [access control models](https://en.wikipedia.org/wiki/Computer_security_model).
+
+## All the languages supported by Casbin:
+
+- Golang: [Casbin](https://github.com/casbin/casbin) (production-ready)
+- Java: [jCasbin](https://github.com/casbin/jcasbin) (production-ready)
+- PHP: [PHP-Casbin](https://github.com/sstutz/php-casbin) (experimental)
+- Node.js: [node-casbin](https://github.com/casbin/node-casbin) (WIP)
+- C++: xCasbin (WIP)
+
+## Table of contents
+
+- [Supported models](#supported-models)
+- [How it works?](#how-it-works)
+- [Features](#features)
+- [Installation](#installation)
+- [Documentation](#documentation)
+- [Online editor](#online-editor)
+- [Tutorials](#tutorials)
+- [Get started](#get-started)
+- [Policy management](#policy-management)
+- [Policy persistence](#policy-persistence)
+- [Policy consistence between multiple nodes](#policy-consistence-between-multiple-nodes)
+- [Role manager](#role-manager)
+- [Multi-threading](#multi-threading)
+- [Benchmarks](#benchmarks)
+- [Examples](#examples)
+- [How to use Casbin as a service?](#how-to-use-casbin-as-a-service)
+- [Our adopters](#our-adopters)
+
+## Supported models
+
+1. [**ACL (Access Control List)**](https://en.wikipedia.org/wiki/Access_control_list)
+2. **ACL with [superuser](https://en.wikipedia.org/wiki/Superuser)**
+3. **ACL without users**: especially useful for systems that don't have authentication or user log-ins.
+4. **ACL without resources**: some scenarios may target a type of resource instead of an individual resource, using permissions like ``write-article``, ``read-log``. It doesn't control the access to a specific article or log.
+5. **[RBAC (Role-Based Access Control)](https://en.wikipedia.org/wiki/Role-based_access_control)**
+6. **RBAC with resource roles**: both users and resources can have roles (or groups) at the same time.
+7. **RBAC with domains/tenants**: users can have different role sets for different domains/tenants.
+## How it works? + +In Casbin, an access control model is abstracted into a CONF file based on the **PERM metamodel (Policy, Effect, Request, Matchers)**. Switching or upgrading the authorization mechanism for a project is then as simple as modifying a configuration. You can customize your own access control model by combining the available models. For example, you can combine RBAC roles and ABAC attributes inside one model and share one set of policy rules. + +The most basic and simplest model in Casbin is ACL. ACL's model CONF is: + +```ini +# Request definition +[request_definition] +r = sub, obj, act + +# Policy definition +[policy_definition] +p = sub, obj, act + +# Policy effect +[policy_effect] +e = some(where (p.eft == allow)) + +# Matchers +[matchers] +m = r.sub == p.sub && r.obj == p.obj && r.act == p.act + +# We also support multi-line mode by appending '\' at the end: +# m = r.sub == p.sub && r.obj == p.obj \ +# && r.act == p.act +``` + +An example policy for the ACL model looks like this: + +``` +p, alice, data1, read +p, bob, data2, write +``` + +It means: + +- alice can read data1 +- bob can write data2
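To make the request/matcher flow concrete, here is a minimal sketch that loads the ACL model and policy shown above (the file paths are placeholders) and evaluates one allowed and one denied request:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	// Load the ACL model CONF and the two-line policy from the text above.
	e := casbin.NewEnforcer("path/to/model.conf", "path/to/policy.csv")

	// Matches the rule "p, alice, data1, read" -> allowed.
	fmt.Println(e.Enforce("alice", "data1", "read")) // true

	// No rule matches, and the effect some(where (p.eft == allow))
	// requires at least one matching allow rule -> denied.
	fmt.Println(e.Enforce("alice", "data2", "write")) // false
}
```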
+## Features + +What Casbin does: + +1. enforce the policy in the classic ``{subject, object, action}`` form or a customized form as you define it; both allow and deny authorizations are supported. +2. handle the storage of the access control model and its policy. +3. manage the role-user mappings and role-role mappings (aka role hierarchy in RBAC). +4. support built-in superusers like ``root`` or ``administrator``. A superuser can do anything without explicit permissions. +5. provide multiple built-in operators to support rule matching. For example, ``keyMatch`` can map a resource key ``/foo/bar`` to the pattern ``/foo*``. + +What Casbin does NOT do: + +1. authentication (aka verify ``username`` and ``password`` when a user logs in) +2. manage the list of users or roles. I believe it's more convenient for the project itself to manage these entities. Users usually have their passwords, and Casbin is not designed as a password container. However, Casbin stores the user-role mapping for the RBAC scenario. + +## Installation + +``` +go get github.com/casbin/casbin +``` + +## Documentation + +For documentation, please see: [Our Wiki](https://github.com/casbin/casbin/wiki) + +## Online editor + +You can also use the online editor (http://casbin.org/editor/) to write your Casbin model and policy in your web browser. It provides functionality such as ``syntax highlighting`` and ``code completion``, just like an IDE for a programming language. + +## Tutorials + +- [Basic Role-Based HTTP Authorization in Go with Casbin](https://zupzup.org/casbin-http-role-auth) (or [Chinese translation](https://studygolang.com/articles/12323)) +- [Policy enforcements on Kubernetes with Banzai Cloud's Pipeline and Casbin](https://banzaicloud.com/blog/policy-enforcement-k8s/) +- [Using Casbin with Beego: 1. Get started and test (in Chinese)](https://blog.csdn.net/hotqin888/article/details/78460385) +- [Using Casbin with Beego: 2. Policy storage (in Chinese)](https://blog.csdn.net/hotqin888/article/details/78571240) +- [Using Casbin with Beego: 3. Policy query (in Chinese)](https://blog.csdn.net/hotqin888/article/details/78992250) +- [Using Casbin with Beego: 4. Policy update (in Chinese)](https://blog.csdn.net/hotqin888/article/details/80032538) + +## Get started + +1. Create a new Casbin enforcer with a model file and a policy file: + + ```go + e := casbin.NewEnforcer("path/to/model.conf", "path/to/policy.csv") + ``` + +Note: you can also initialize an enforcer with a policy stored in a DB instead of a file; see the [Policy persistence](#policy-persistence) section for details. + +2. Add an enforcement hook into your code right before the access happens: + + ```go + sub := "alice" // the user that wants to access a resource. + obj := "data1" // the resource that is going to be accessed. + act := "read" // the operation that the user performs on the resource. + + if e.Enforce(sub, obj, act) { + // permit alice to read data1 + } else { + // deny the request, show an error + } + ``` + +3. Besides the static policy file, Casbin also provides APIs for permission management at run-time. For example, you can get all the roles assigned to a user as below: + + ```go + roles := e.GetRoles("alice") + ``` + +See [Policy management APIs](#policy-management) for more usage. + +4. Please refer to the ``_test.go`` files for more usage. + +## Policy management + +Casbin provides two sets of APIs to manage permissions: + +- [Management API](https://github.com/casbin/casbin/blob/master/management_api.go): the primitive API that provides full support for Casbin policy management. See [here](https://github.com/casbin/casbin/blob/master/management_api_test.go) for examples. +- [RBAC API](https://github.com/casbin/casbin/blob/master/rbac_api.go): a more friendly API for RBAC. This API is a subset of the Management API. RBAC users can use it to simplify their code. See [here](https://github.com/casbin/casbin/blob/master/rbac_api_test.go) for examples. + +We also provide a web-based UI for model management and policy management: + +![model editor](https://hsluoyz.github.io/casbin/ui_model_editor.png) + +![policy editor](https://hsluoyz.github.io/casbin/ui_policy_editor.png)
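Beyond ``GetRoles``, the Management API can inspect and mutate rules at run-time. A rough sketch, reusing the enforcer ``e`` from the quick start above (with auto-save enabled, the default, each change is also written through the configured adapter):

```go
// Add a rule; returns false if the rule already exists.
added := e.AddPolicy("eve", "data3", "read")

// Check for a specific rule, and list all current "p" rules.
exists := e.HasPolicy("eve", "data3", "read")
rules := e.GetPolicy() // [][]string, one slice per policy line

// Remove the rule again; returns false if it was not present.
removed := e.RemovePolicy("eve", "data3", "read")

_, _, _, _ = added, exists, rules, removed
```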
+## Policy persistence + +In Casbin, the policy storage is implemented as an adapter (aka middleware for Casbin). To keep the library lightweight, we don't put adapter code in the main library (except the default file adapter). A complete list of Casbin adapters is provided below. Any third-party contribution of a new adapter is welcome; please inform us, and we will put it in this list. :) + +Adapter | Type | Author | Description +----|------|----|---- +[File Adapter (built-in)](https://github.com/casbin/casbin/wiki/Policy-persistence#file-adapter) | File | Casbin | Persistence for [.CSV (Comma-Separated Values)](https://en.wikipedia.org/wiki/Comma-separated_values) files +[Filtered File Adapter (built-in)](https://github.com/casbin/casbin#policy-enforcement-at-scale) | File | [@faceless-saint](https://github.com/faceless-saint) | Persistence for [.CSV (Comma-Separated Values)](https://en.wikipedia.org/wiki/Comma-separated_values) files with policy subset loading support +[Xorm Adapter](https://github.com/casbin/xorm-adapter) | ORM | Casbin | MySQL, PostgreSQL, TiDB, SQLite, SQL Server, Oracle are supported by [Xorm](https://github.com/go-xorm/xorm/) +[Gorm Adapter](https://github.com/casbin/gorm-adapter) | ORM | Casbin | MySQL, PostgreSQL, Sqlite3, SQL Server are supported by [Gorm](https://github.com/jinzhu/gorm/) +[Beego ORM Adapter](https://github.com/casbin/beego-orm-adapter) | ORM | Casbin | MySQL, PostgreSQL, Sqlite3 are supported by [Beego ORM](https://beego.me/docs/mvc/model/overview.md) +[MongoDB Adapter](https://github.com/casbin/mongodb-adapter) | NoSQL | Casbin | Persistence for [MongoDB](https://www.mongodb.com) +[Cassandra Adapter](https://github.com/casbin/cassandra-adapter) | NoSQL | Casbin | Persistence for [Apache Cassandra DB](http://cassandra.apache.org) +[Consul Adapter](https://github.com/ankitm123/consul-adapter) | KV store | [@ankitm123](https://github.com/ankitm123) | Persistence for [HashiCorp Consul](https://www.consul.io/) +[Redis Adapter](https://github.com/casbin/redis-adapter) | KV store | Casbin | Persistence for [Redis](https://redis.io/) +[Protobuf Adapter](https://github.com/casbin/protobuf-adapter) | Stream | Casbin | Persistence for [Google Protocol Buffers](https://developers.google.com/protocol-buffers/) +[JSON Adapter](https://github.com/casbin/json-adapter) | String | Casbin | Persistence for [JSON](https://www.json.org/) +[String Adapter](https://github.com/qiangmzsx/string-adapter) | String | [@qiangmzsx](https://github.com/qiangmzsx) | Persistence for String +[RQLite Adapter](https://github.com/edomosystems/rqlite-adapter) | SQL | [EDOMO Systems](https://github.com/edomosystems) | Persistence for [RQLite](https://github.com/rqlite/rqlite/) +[PostgreSQL Adapter](https://github.com/going/casbin-postgres-adapter) | SQL | [Going](https://github.com/going) | Persistence for [PostgreSQL](https://www.postgresql.org/) +[RethinkDB Adapter](https://github.com/adityapandey9/rethinkdb-adapter) | NoSQL | [@adityapandey9](https://github.com/adityapandey9) | Persistence for [RethinkDB](https://rethinkdb.com/) +[DynamoDB Adapter](https://github.com/HOOQTV/dynacasbin) | NoSQL | [HOOQ](https://github.com/HOOQTV) | Persistence for [Amazon DynamoDB](https://aws.amazon.com/dynamodb/) +[Minio/AWS S3 Adapter](https://github.com/Soluto/casbin-minio-adapter) | Object storage | [Soluto](https://github.com/Soluto) | Persistence for [Minio](https://github.com/minio/minio) and [Amazon S3](https://aws.amazon.com/s3/) +[Bolt Adapter](https://github.com/wirepair/bolt-adapter) | KV store | [@wirepair](https://github.com/wirepair) | Persistence for [Bolt](https://github.com/boltdb/bolt) + +For details of adapters, please refer to the documentation: https://github.com/casbin/casbin/wiki/Policy-persistence
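Every adapter plugs into the enforcer the same way. A minimal sketch using the built-in file adapter (its constructor, ``fileadapter.NewAdapter``, appears in the vendored ``enforcer.go`` later in this diff); a third-party adapter such as Xorm or Gorm would be constructed per its own README and passed in the same position:

```go
package main

import (
	"github.com/casbin/casbin"
	fileadapter "github.com/casbin/casbin/persist/file-adapter"
)

func main() {
	// Any persist.Adapter implementation can be passed as the second argument.
	a := fileadapter.NewAdapter("path/to/policy.csv")
	e := casbin.NewEnforcer("path/to/model.conf", a)

	// LoadPolicy re-reads the policy from the adapter's backing store;
	// SavePolicy writes the in-memory policy back (errors ignored for brevity).
	e.LoadPolicy()
	e.SavePolicy()
}
```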
+## Policy enforcement at scale + +Some adapters support filtered policy management. This means that the policy loaded by Casbin is a subset of the policy in storage based on a given filter. This allows for efficient policy enforcement in large, multi-tenant environments when parsing the entire policy becomes a performance bottleneck. + +To use filtered policies with a supported adapter, simply call the `LoadFilteredPolicy` method. The valid format for the filter parameter depends on the adapter used. To prevent accidental data loss, the `SavePolicy` method is disabled when a filtered policy is loaded. + +For example, the following code snippet uses the built-in filtered file adapter and the RBAC model with domains. In this case, the filter limits the policy to a single domain. Any policy lines for domains other than `"domain1"` are omitted from the loaded policy: + +```go +import ( + "github.com/casbin/casbin" + fileadapter "github.com/casbin/casbin/persist/file-adapter" +) + +enforcer := casbin.NewEnforcer() + +adapter := fileadapter.NewFilteredAdapter("examples/rbac_with_domains_policy.csv") +enforcer.InitWithAdapter("examples/rbac_with_domains_model.conf", adapter) + +filter := &fileadapter.Filter{ + P: []string{"", "domain1"}, + G: []string{"", "", "domain1"}, +} +enforcer.LoadFilteredPolicy(filter) + +// The loaded policy now only contains the entries pertaining to "domain1". +``` + +## Policy consistency between multiple nodes + +We support using distributed messaging systems like [etcd](https://github.com/coreos/etcd) to keep consistency between multiple Casbin enforcer instances, so users can concurrently run multiple Casbin enforcers to handle large numbers of permission-checking requests. + +Similar to policy storage adapters, we don't put watcher code in the main library. Any support for a new messaging system should be implemented as a watcher. A complete list of Casbin watchers is provided below. Any third-party contribution of a new watcher is welcome; please inform us, and we will put it in this list. :) + +Watcher | Type | Author | Description +----|------|----|---- +[Etcd Watcher](https://github.com/casbin/etcd-watcher) | KV store | Casbin | Watcher for [etcd](https://github.com/coreos/etcd) +[NATS Watcher](https://github.com/Soluto/casbin-nats-watcher) | Messaging system | [Soluto](https://github.com/Soluto) | Watcher for [NATS](https://nats.io/) +[ZooKeeper Watcher](https://github.com/grepsr/casbin-zk-watcher) | KV store | [Grepsr](https://github.com/grepsr) | Watcher for [Apache ZooKeeper](https://zookeeper.apache.org/) +[Redis Watcher](https://github.com/billcobbler/casbin-redis-watcher) | KV store | [@billcobbler](https://github.com/billcobbler) | Watcher for [Redis](http://redis.io/)
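As used by the vendored ``enforcer.go`` later in this diff, a watcher only needs the two methods the enforcer calls: ``SetUpdateCallback``, which receives the enforcer's reload callback (a closure that calls ``LoadPolicy()``), and ``Update``, which ``SavePolicy()`` invokes after a successful save. A toy in-process sketch for illustration; a real watcher such as the etcd one would broadcast the update over the network instead:

```go
package watcherdemo

// SimpleWatcher is a toy persist.Watcher implementation, useful only to
// show the contract; all enforcers sharing one instance reload on save.
type SimpleWatcher struct {
	callback func(string)
}

// SetUpdateCallback stores the enforcer's reload callback; the enforcer
// registers it once via SetWatcher.
func (w *SimpleWatcher) SetUpdateCallback(cb func(string)) error {
	w.callback = cb
	return nil
}

// Update signals that the policy changed; here it just fires the local
// callback, while a real watcher would notify the other nodes.
func (w *SimpleWatcher) Update() error {
	if w.callback != nil {
		w.callback("")
	}
	return nil
}
```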
+## Role manager + +The role manager is used to manage the RBAC role hierarchy (user-role mapping) in Casbin. A role manager can retrieve the role data from Casbin policy rules or external sources such as LDAP, Okta, Auth0, Azure AD, etc. We support different implementations of a role manager. To keep the library lightweight, we don't put role manager code in the main library (except the default role manager). A complete list of Casbin role managers is provided below. Any third-party contribution of a new role manager is welcome; please inform us, and we will put it in this list. :) + +Role manager | Author | Description +----|----|---- +[Default Role Manager (built-in)](https://github.com/casbin/casbin/blob/master/rbac/default-role-manager/role_manager.go) | Casbin | Supports role hierarchy stored in Casbin policy +[Session Role Manager](https://github.com/casbin/session-role-manager) | [EDOMO Systems](https://github.com/edomosystems) | Supports role hierarchy stored in Casbin policy, with time-range-based sessions +[Okta Role Manager](https://github.com/casbin/okta-role-manager) | Casbin | Supports role hierarchy stored in [Okta](https://www.okta.com/) +[Auth0 Role Manager](https://github.com/casbin/auth0-role-manager) | Casbin | Supports role hierarchy stored in [Auth0](https://auth0.com/)'s [Authorization Extension](https://auth0.com/docs/extensions/authorization-extension/v2) + +For developers: all role managers must implement the [RoleManager](https://github.com/casbin/casbin/blob/master/rbac/role_manager.go) interface. [Session Role Manager](https://github.com/casbin/session-role-manager) can be used as a reference implementation. + +## Multi-threading + +If you use Casbin from multiple threads, you can use the synchronized wrapper of the Casbin enforcer: https://github.com/casbin/casbin/blob/master/enforcer_synced.go. + +It also supports the ``AutoLoad`` feature, which means the Casbin enforcer will automatically load the latest policy rules from the DB if they have changed. Call ``StartAutoLoadPolicy()`` to start automatically loading the policy periodically and call ``StopAutoLoadPolicy()`` to stop it. + +## Benchmarks + +The overhead of policy enforcement is benchmarked in [model_b_test.go](https://github.com/casbin/casbin/blob/master/model_b_test.go). The testbed is: + +``` +Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz, 2601 Mhz, 4 Core(s), 8 Logical Processor(s) +``` + +The benchmarking result of ``go test -bench=.
-benchmem`` is as follows (op = an ``Enforce()`` call, ms = millisecond, KB = kilo bytes): + +Test case | Size | Time overhead | Memory overhead +----|------|------|---- +ACL | 2 rules (2 users) | 0.015493 ms/op | 5.649 KB +RBAC | 5 rules (2 users, 1 role) | 0.021738 ms/op | 7.522 KB +RBAC (small) | 1100 rules (1000 users, 100 roles) | 0.164309 ms/op | 80.620 KB +RBAC (medium) | 11000 rules (10000 users, 1000 roles) | 2.258262 ms/op | 765.152 KB +RBAC (large) | 110000 rules (100000 users, 10000 roles) | 23.916776 ms/op | 7.606 MB +RBAC with resource roles | 6 rules (2 users, 2 roles) | 0.021146 ms/op | 7.906 KB +RBAC with domains/tenants | 6 rules (2 users, 1 role, 2 domains) | 0.032696 ms/op | 10.755 KB +ABAC | 0 rule (0 user) | 0.007510 ms/op | 2.328 KB +RESTful | 5 rules (3 users) | 0.045398 ms/op | 91.774 KB +Deny-override | 6 rules (2 users, 1 role) | 0.023281 ms/op | 8.370 KB +Priority | 9 rules (2 users, 2 roles) | 0.016389 ms/op | 5.313 KB + +## Examples + +Model | Model file | Policy file +----|------|---- +ACL | [basic_model.conf](https://github.com/casbin/casbin/blob/master/examples/basic_model.conf) | [basic_policy.csv](https://github.com/casbin/casbin/blob/master/examples/basic_policy.csv) +ACL with superuser | [basic_model_with_root.conf](https://github.com/casbin/casbin/blob/master/examples/basic_with_root_model.conf) | [basic_policy.csv](https://github.com/casbin/casbin/blob/master/examples/basic_policy.csv) +ACL without users | [basic_model_without_users.conf](https://github.com/casbin/casbin/blob/master/examples/basic_without_users_model.conf) | [basic_policy_without_users.csv](https://github.com/casbin/casbin/blob/master/examples/basic_without_users_policy.csv) +ACL without resources | [basic_model_without_resources.conf](https://github.com/casbin/casbin/blob/master/examples/basic_without_resources_model.conf) | [basic_policy_without_resources.csv](https://github.com/casbin/casbin/blob/master/examples/basic_without_resources_policy.csv) +RBAC | [rbac_model.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_model.conf) | [rbac_policy.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_policy.csv) +RBAC with resource roles | [rbac_model_with_resource_roles.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_resource_roles_model.conf) | [rbac_policy_with_resource_roles.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_resource_roles_policy.csv) +RBAC with domains/tenants | [rbac_model_with_domains.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_domains_model.conf) | [rbac_policy_with_domains.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_domains_policy.csv) +ABAC | [abac_model.conf](https://github.com/casbin/casbin/blob/master/examples/abac_model.conf) | N/A +RESTful | [keymatch_model.conf](https://github.com/casbin/casbin/blob/master/examples/keymatch_model.conf) | [keymatch_policy.csv](https://github.com/casbin/casbin/blob/master/examples/keymatch_policy.csv) +Deny-override | [rbac_model_with_deny.conf](https://github.com/casbin/casbin/blob/master/examples/rbac_with_deny_model.conf) | [rbac_policy_with_deny.csv](https://github.com/casbin/casbin/blob/master/examples/rbac_with_deny_policy.csv) +Priority | [priority_model.conf](https://github.com/casbin/casbin/blob/master/examples/priority_model.conf) | [priority_policy.csv](https://github.com/casbin/casbin/blob/master/examples/priority_policy.csv) + +## How to use Casbin as a service? 
+ +- [Go-Simple-API-Gateway](https://github.com/Soontao/go-simple-api-gateway): A simple API gateway written in Go, with support for authentication and authorization +- [Casbin Server](https://github.com/casbin/casbin-server): Casbin as a service via a RESTful API; only the permission-checking API is exposed +- [middleware-acl](https://github.com/luk4z7/middleware-acl): RESTful access control middleware based on Casbin + +## Our adopters + +### Web frameworks + +- [Beego](https://github.com/astaxie/beego): An open-source, high-performance web framework for Go, via built-in plugin: [plugins/authz](https://github.com/astaxie/beego/blob/master/plugins/authz) +- [Caddy](https://github.com/mholt/caddy): Fast, cross-platform HTTP/2 web server with automatic HTTPS, via plugin: [caddy-authz](https://github.com/casbin/caddy-authz) +- [Gin](https://github.com/gin-gonic/gin): An HTTP web framework featuring a Martini-like API with much better performance, via plugin: [authz](https://github.com/gin-contrib/authz) +- [Revel](https://github.com/revel/revel): A high-productivity, full-stack web framework for the Go language, via plugin: [auth/casbin](https://github.com/revel/modules/tree/master/auth/casbin) +- [Echo](https://github.com/labstack/echo): High performance, minimalist Go web framework, via plugin: [echo-authz](https://github.com/labstack/echo-contrib/tree/master/casbin) (thanks to [@xqbumu](https://github.com/xqbumu)) +- [Iris](https://github.com/kataras/iris): The fastest web framework for Go in (THIS) Earth. HTTP/2 Ready-To-GO, via plugin: [casbin](https://github.com/iris-contrib/middleware/tree/master/casbin) (thanks to [@hiveminded](https://github.com/hiveminded)) +- [Negroni](https://github.com/urfave/negroni): Idiomatic HTTP Middleware for Golang, via plugin: [negroni-authz](https://github.com/casbin/negroni-authz) +- [Tango](https://github.com/lunny/tango): Micro & pluggable web framework for Go, via plugin: [authz](https://github.com/tango-contrib/authz) +- [Chi](https://github.com/pressly/chi): A lightweight, idiomatic and composable router for building HTTP services, via plugin: [chi-authz](https://github.com/casbin/chi-authz) +- [Macaron](https://github.com/go-macaron/macaron): A highly productive and modular web framework in Go, via plugin: [authz](https://github.com/go-macaron/authz) +- [DotWeb](https://github.com/devfeel/dotweb): A simple and easy Go web micro framework, via plugin: [authz](https://github.com/devfeel/middleware/tree/master/authz) +- [Baa](https://github.com/go-baa/baa): An express Go web framework with routing, middleware, dependency injection and http context, via plugin: [authz](https://github.com/baa-middleware/authz) + +### Others + +- [Intel RMD](https://github.com/intel/rmd): Intel's resource management daemon, via direct integration, see: [model](https://github.com/intel/rmd/blob/master/etc/rmd/acl/url/model.conf), [policy rules](https://github.com/intel/rmd/blob/master/etc/rmd/acl/url/policy.csv) +- [VMware Dispatch](https://github.com/vmware/dispatch): A framework for deploying and managing serverless style applications, via direct integration, see: [model (in code)](https://github.com/vmware/dispatch/blob/master/pkg/identity-manager/handlers.go#L46-L55), [policy rules (in code)](https://github.com/vmware/dispatch/blob/master/pkg/identity-manager/handlers_test.go#L35-L45) +- [Banzai Pipeline](https://github.com/banzaicloud/pipeline): [Banzai Cloud](https://github.com/banzaicloud)'s RESTful API to provision or reuse managed Kubernetes clusters in the cloud, via direct
integration, see: [model (in code)](https://github.com/banzaicloud/pipeline/blob/master/auth/authz.go#L15-L30), [policy rules (in code)](https://github.com/banzaicloud/pipeline/blob/master/auth/authz.go#L84-L93) +- [Docker](https://github.com/docker/docker): The world's leading software container platform, via plugin: [casbin-authz-plugin](https://github.com/casbin/casbin-authz-plugin) ([recommended by Docker](https://docs.docker.com/engine/extend/legacy_plugins/#authorization-plugins)) +- [Gobis](https://github.com/orange-cloudfoundry/gobis): [Orange](https://github.com/orange-cloudfoundry)'s lightweight API Gateway written in Go, via plugin: [casbin](https://github.com/orange-cloudfoundry/gobis-middlewares/tree/master/casbin), see [model (in code)](https://github.com/orange-cloudfoundry/gobis-middlewares/blob/master/casbin/model.go#L52-L65), [policy rules (from request)](https://github.com/orange-cloudfoundry/gobis-middlewares/blob/master/casbin/adapter.go#L46-L64) +- [Skydive](https://github.com/skydive-project/skydive): An open source real-time network topology and protocols analyzer, via direct integration, see: [model (in code)](https://github.com/skydive-project/skydive/blob/master/config/config.go#L136-L140), [policy rules](https://github.com/skydive-project/skydive/blob/master/rbac/policy.csv) +- [Zenpress](https://github.com/insionng/zenpress): A CMS written in Golang, via direct integration, see: [model](https://github.com/insionng/zenpress/blob/master/content/config/rbac_model.conf), [policy rules (in Gorm)](https://github.com/insionng/zenpress/blob/master/model/user.go#L53-L77) +- [Argo CD](https://github.com/argoproj/argo-cd): GitOps continuous delivery for Kubernetes, via direct integration, see: [model](https://github.com/argoproj/argo-cd/blob/master/util/rbac/model.conf), [policy rules](https://github.com/argoproj/argo-cd/blob/master/util/rbac/builtin-policy.csv) +- [EngineerCMS](https://github.com/3xxx/EngineerCMS): A CMS to manage knowledge for engineers, via direct integration, see: [model](https://github.com/3xxx/EngineerCMS/blob/master/conf/rbac_model.conf), [policy rules (in SQLite)](https://github.com/3xxx/EngineerCMS/blob/master/database/engineer.db) +- [Cyber Auth API](https://github.com/CyberlifeCN/cyber-auth-api): A Golang authentication API project, via direct integration, see: [model](https://github.com/CyberlifeCN/cyber-auth-api/blob/master/conf/authz_model.conf), [policy rules](https://github.com/CyberlifeCN/cyber-auth-api/blob/master/conf/authz_policy.csv) +- [IRIS Community](https://github.com/irisnet/iris-community): Website for IRIS Community Activities, via direct integration, see: [model](https://github.com/irisnet/iris-community/blob/master/authz/authz_model.conf), [policy rules](https://github.com/irisnet/iris-community/blob/master/authz/authz_policy.csv) +- [Metadata DB](https://github.com/Bnei-Baruch/mdb): BB archive metadata database, via direct integration, see: [model](https://github.com/Bnei-Baruch/mdb/blob/master/data/permissions_model.conf), [policy rules](https://github.com/Bnei-Baruch/mdb/blob/master/data/permissions_policy.csv) + +## License + +This project is licensed under the [Apache 2.0 license](https://github.com/casbin/casbin/blob/master/LICENSE). + +## Contact + +If you have any issues or feature requests, please contact us. PRs are welcome.
+- https://github.com/casbin/casbin/issues +- hsluoyz@gmail.com +- Tencent QQ group: [546057381](//shang.qq.com/wpa/qunwpa?idkey=8ac8b91fc97ace3d383d0035f7aa06f7d670fd8e8d4837347354a31c18fac885) + +## Donation + +[![Patreon](https://hsluoyz.github.io/donation/patreon.png)](http://www.patreon.com/yangluo) + +![Alipay](https://hsluoyz.github.io/donation/donate_alipay.png) +![Wechat](https://hsluoyz.github.io/donation/donate_weixin.png) diff --git a/vendor/github.com/casbin/casbin/casbin-logo.png b/vendor/github.com/casbin/casbin/casbin-logo.png new file mode 100644 index 00000000..7e5d1ecf Binary files /dev/null and b/vendor/github.com/casbin/casbin/casbin-logo.png differ diff --git a/vendor/github.com/casbin/casbin/config/config.go b/vendor/github.com/casbin/casbin/config/config.go new file mode 100644 index 00000000..f188dfc8 --- /dev/null +++ b/vendor/github.com/casbin/casbin/config/config.go @@ -0,0 +1,266 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" +) + +var ( + // DEFAULT_SECTION specifies the name of a section if no name provided + DEFAULT_SECTION = "default" + // DEFAULT_COMMENT defines what character(s) indicate a comment `#` + DEFAULT_COMMENT = []byte{'#'} + // DEFAULT_COMMENT_SEM defines what alternate character(s) indicate a comment `;` + DEFAULT_COMMENT_SEM = []byte{';'} + // DEFAULT_MULTI_LINE_SEPARATOR defines what character indicates a multi-line content + DEFAULT_MULTI_LINE_SEPARATOR = []byte{'\\'} +) + +// ConfigInterface defines the behavior of a Config implementation +type ConfigInterface interface { + String(key string) string + Strings(key string) []string + Bool(key string) (bool, error) + Int(key string) (int, error) + Int64(key string) (int64, error) + Float64(key string) (float64, error) + Set(key string, value string) error +} + +// Config represents an implementation of the ConfigInterface +type Config struct { + // the data map itself is not concurrency-safe; it is guarded by the embedded RWMutex. + sync.RWMutex + // Section:key=value + data map[string]map[string]string +} + +// NewConfig creates an empty configuration representation from file. +func NewConfig(confName string) (ConfigInterface, error) { + c := &Config{ + data: make(map[string]map[string]string), + } + err := c.parse(confName) + return c, err +} + +// NewConfigFromText creates an empty configuration representation from text. +func NewConfigFromText(text string) (ConfigInterface, error) { + c := &Config{ + data: make(map[string]map[string]string), + } + err := c.parseBuffer(bufio.NewReader(strings.NewReader(text))) + return c, err +} + +// AddConfig adds a new section->key:value to the configuration.
+func (c *Config) AddConfig(section string, option string, value string) bool { + if section == "" { + section = DEFAULT_SECTION + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + + _, ok := c.data[section][option] + c.data[section][option] = value + + return !ok +} + +func (c *Config) parse(fname string) (err error) { + c.Lock() + f, err := os.Open(fname) + if err != nil { + return err + } + defer c.Unlock() + defer f.Close() + + buf := bufio.NewReader(f) + return c.parseBuffer(buf) +} + +func (c *Config) parseBuffer(buf *bufio.Reader) error { + var section string + var lineNum int + var buffer bytes.Buffer + var canWrite bool + for { + if canWrite { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } else { + canWrite = false + } + } + lineNum++ + line, _, err := buf.ReadLine() + if err == io.EOF { + // force write when buffer is not flushed yet + if buffer.Len() > 0 { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } + } + break + } else if err != nil { + return err + } + + line = bytes.TrimSpace(line) + switch { + case bytes.Equal(line, []byte{}), bytes.HasPrefix(line, DEFAULT_COMMENT_SEM), + bytes.HasPrefix(line, DEFAULT_COMMENT): + canWrite = true + continue + case bytes.HasPrefix(line, []byte{'['}) && bytes.HasSuffix(line, []byte{']'}): + // force write when buffer is not flushed yet + if buffer.Len() > 0 { + if err := c.write(section, lineNum, &buffer); err != nil { + return err + } + canWrite = false + } + section = string(line[1 : len(line)-1]) + default: + var p []byte + if bytes.HasSuffix(line, DEFAULT_MULTI_LINE_SEPARATOR) { + p = bytes.TrimSpace(line[:len(line)-1]) + } else { + p = line + canWrite = true + } + + if _, err := buffer.Write(p); err != nil { + return err + } + } + } + + return nil +} + +func (c *Config) write(section string, lineNum int, b *bytes.Buffer) error { + if b.Len() <= 0 { + return nil + } + + optionVal := bytes.SplitN(b.Bytes(), []byte{'='}, 2) + if len(optionVal) != 2 { + return fmt.Errorf("parse the content error : line %d , %s = ? 
", lineNum, optionVal[0]) + } + option := bytes.TrimSpace(optionVal[0]) + value := bytes.TrimSpace(optionVal[1]) + c.AddConfig(section, string(option), string(value)) + + // flush buffer after adding + b.Reset() + + return nil +} + +// Bool lookups up the value using the provided key and converts the value to a bool +func (c *Config) Bool(key string) (bool, error) { + return strconv.ParseBool(c.get(key)) +} + +// Int lookups up the value using the provided key and converts the value to a int +func (c *Config) Int(key string) (int, error) { + return strconv.Atoi(c.get(key)) +} + +// Int64 lookups up the value using the provided key and converts the value to a int64 +func (c *Config) Int64(key string) (int64, error) { + return strconv.ParseInt(c.get(key), 10, 64) +} + +// Float64 lookups up the value using the provided key and converts the value to a float64 +func (c *Config) Float64(key string) (float64, error) { + return strconv.ParseFloat(c.get(key), 64) +} + +// String lookups up the value using the provided key and converts the value to a string +func (c *Config) String(key string) string { + return c.get(key) +} + +// Strings lookups up the value using the provided key and converts the value to an array of string +// by splitting the string by comma +func (c *Config) Strings(key string) []string { + v := c.get(key) + if v == "" { + return nil + } + return strings.Split(v, ",") +} + +// Set sets the value for the specific key in the Config +func (c *Config) Set(key string, value string) error { + c.Lock() + defer c.Unlock() + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section string + option string + ) + + keys := strings.Split(strings.ToLower(key), "::") + if len(keys) >= 2 { + section = keys[0] + option = keys[1] + } else { + option = keys[0] + } + + c.AddConfig(section, option, value) + return nil +} + +// section.key or key +func (c *Config) get(key string) string { + var ( + section string + option string + ) + + keys := strings.Split(strings.ToLower(key), "::") + if len(keys) >= 2 { + section = keys[0] + option = keys[1] + } else { + section = DEFAULT_SECTION + option = keys[0] + } + + if value, ok := c.data[section][option]; ok { + return value + } + + return "" +} diff --git a/vendor/github.com/casbin/casbin/effect/default_effector.go b/vendor/github.com/casbin/casbin/effect/default_effector.go new file mode 100644 index 00000000..7368e19f --- /dev/null +++ b/vendor/github.com/casbin/casbin/effect/default_effector.go @@ -0,0 +1,75 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package effect + +import "github.com/pkg/errors" + +// DefaultEffector is default effector for Casbin. +type DefaultEffector struct { +} + +// NewDefaultEffector is the constructor for DefaultEffector. +func NewDefaultEffector() *DefaultEffector { + e := DefaultEffector{} + return &e +} + +// MergeEffects merges all matching results collected by the enforcer into a single decision. 
+func (e *DefaultEffector) MergeEffects(expr string, effects []Effect, results []float64) (bool, error) { + result := false + if expr == "some(where (p_eft == allow))" { + result = false + for _, eft := range effects { + if eft == Allow { + result = true + break + } + } + } else if expr == "!some(where (p_eft == deny))" { + result = true + for _, eft := range effects { + if eft == Deny { + result = false + break + } + } + } else if expr == "some(where (p_eft == allow)) && !some(where (p_eft == deny))" { + result = false + for _, eft := range effects { + if eft == Allow { + result = true + } else if eft == Deny { + result = false + break + } + } + } else if expr == "priority(p_eft) || deny" { + result = false + for _, eft := range effects { + if eft != Indeterminate { + if eft == Allow { + result = true + } else { + result = false + } + break + } + } + } else { + return false, errors.New("unsupported effect") + } + + return result, nil +} diff --git a/vendor/github.com/casbin/casbin/effect/effector.go b/vendor/github.com/casbin/casbin/effect/effector.go new file mode 100644 index 00000000..52271b3f --- /dev/null +++ b/vendor/github.com/casbin/casbin/effect/effector.go @@ -0,0 +1,31 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package effect + +// Effect is the result for a policy rule. +type Effect int + +// Values for policy effect. +const ( + Allow Effect = iota + Indeterminate + Deny +) + +// Effector is the interface for Casbin effectors. +type Effector interface { + // MergeEffects merges all matching results collected by the enforcer into a single decision. + MergeEffects(expr string, effects []Effect, results []float64) (bool, error) +} diff --git a/vendor/github.com/casbin/casbin/enforcer.go b/vendor/github.com/casbin/casbin/enforcer.go new file mode 100644 index 00000000..a3a5490d --- /dev/null +++ b/vendor/github.com/casbin/casbin/enforcer.go @@ -0,0 +1,425 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package casbin + +import ( + "errors" + "fmt" + + "github.com/Knetic/govaluate" + + "github.com/casbin/casbin/effect" + "github.com/casbin/casbin/model" + "github.com/casbin/casbin/persist" + "github.com/casbin/casbin/persist/file-adapter" + "github.com/casbin/casbin/rbac" + "github.com/casbin/casbin/rbac/default-role-manager" + "github.com/casbin/casbin/util" +) + +// Enforcer is the main interface for authorization enforcement and policy management. +type Enforcer struct { + modelPath string + model model.Model + fm model.FunctionMap + eft effect.Effector + + adapter persist.Adapter + watcher persist.Watcher + rm rbac.RoleManager + + enabled bool + autoSave bool + autoBuildRoleLinks bool +} + +// NewEnforcer creates an enforcer via file or DB. +// File: +// e := casbin.NewEnforcer("path/to/basic_model.conf", "path/to/basic_policy.csv") +// MySQL DB: +// a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/") +// e := casbin.NewEnforcer("path/to/basic_model.conf", a) +func NewEnforcer(params ...interface{}) *Enforcer { + e := &Enforcer{} + e.rm = defaultrolemanager.NewRoleManager(10) + e.eft = effect.NewDefaultEffector() + + parsedParamLen := 0 + if len(params) >= 1 { + enableLog, ok := params[len(params)-1].(bool) + if ok { + e.EnableLog(enableLog) + + parsedParamLen++ + } + } + + if len(params)-parsedParamLen == 2 { + switch params[0].(type) { + case string: + switch params[1].(type) { + case string: + e.InitWithFile(params[0].(string), params[1].(string)) + default: + e.InitWithAdapter(params[0].(string), params[1].(persist.Adapter)) + } + default: + switch params[1].(type) { + case string: + panic("Invalid parameters for enforcer.") + default: + e.InitWithModelAndAdapter(params[0].(model.Model), params[1].(persist.Adapter)) + } + } + } else if len(params)-parsedParamLen == 1 { + switch params[0].(type) { + case string: + e.InitWithFile(params[0].(string), "") + default: + e.InitWithModelAndAdapter(params[0].(model.Model), nil) + } + } else if len(params)-parsedParamLen == 0 { + e.InitWithFile("", "") + } else { + panic("Invalid parameters for enforcer.") + } + + return e +} + +// InitWithFile initializes an enforcer with a model file and a policy file. +func (e *Enforcer) InitWithFile(modelPath string, policyPath string) { + a := fileadapter.NewAdapter(policyPath) + e.InitWithAdapter(modelPath, a) +} + +// InitWithAdapter initializes an enforcer with a database adapter. +func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) { + m := NewModel(modelPath, "") + e.InitWithModelAndAdapter(m, adapter) + + e.modelPath = modelPath +} + +// InitWithModelAndAdapter initializes an enforcer with a model and a database adapter. +func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) { + e.adapter = adapter + e.watcher = nil + + e.model = m + e.model.PrintModel() + e.fm = model.LoadFunctionMap() + + e.initialize() + + if e.adapter != nil { + // error intentionally ignored + e.LoadPolicy() + } +} + +func (e *Enforcer) initialize() { + e.enabled = true + e.autoSave = true + e.autoBuildRoleLinks = true +} + +// NewModel creates a model. +func NewModel(text ...string) model.Model { + m := make(model.Model) + + if len(text) == 2 { + if text[0] != "" { + m.LoadModel(text[0]) + } + } else if len(text) == 1 { + m.LoadModelFromText(text[0]) + } else if len(text) != 0 { + panic("Invalid parameters for model.") + } + + return m +} + +// LoadModel reloads the model from the model CONF file. 
+// Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy(). +func (e *Enforcer) LoadModel() { + e.model = NewModel() + e.model.LoadModel(e.modelPath) + e.model.PrintModel() + e.fm = model.LoadFunctionMap() +} + +// GetModel gets the current model. +func (e *Enforcer) GetModel() model.Model { + return e.model +} + +// SetModel sets the current model. +func (e *Enforcer) SetModel(m model.Model) { + e.model = m + e.fm = model.LoadFunctionMap() +} + +// GetAdapter gets the current adapter. +func (e *Enforcer) GetAdapter() persist.Adapter { + return e.adapter +} + +// SetAdapter sets the current adapter. +func (e *Enforcer) SetAdapter(adapter persist.Adapter) { + e.adapter = adapter +} + +// SetWatcher sets the current watcher. +func (e *Enforcer) SetWatcher(watcher persist.Watcher) { + e.watcher = watcher + // error intentionally ignored + watcher.SetUpdateCallback(func(string) { e.LoadPolicy() }) +} + +// SetRoleManager sets the current role manager. +func (e *Enforcer) SetRoleManager(rm rbac.RoleManager) { + e.rm = rm +} + +// SetEffector sets the current effector. +func (e *Enforcer) SetEffector(eft effect.Effector) { + e.eft = eft +} + +// ClearPolicy clears all policy. +func (e *Enforcer) ClearPolicy() { + e.model.ClearPolicy() +} + +// LoadPolicy reloads the policy from file/database. +func (e *Enforcer) LoadPolicy() error { + e.model.ClearPolicy() + if err := e.adapter.LoadPolicy(e.model); err != nil { + return err + } + + e.model.PrintPolicy() + if e.autoBuildRoleLinks { + e.BuildRoleLinks() + } + return nil +} + +// LoadFilteredPolicy reloads a filtered policy from file/database. +func (e *Enforcer) LoadFilteredPolicy(filter interface{}) error { + e.model.ClearPolicy() + + var filteredAdapter persist.FilteredAdapter + + // Attempt to cast the Adapter as a FilteredAdapter + switch e.adapter.(type) { + case persist.FilteredAdapter: + filteredAdapter = e.adapter.(persist.FilteredAdapter) + default: + return errors.New("filtered policies are not supported by this adapter") + } + if err := filteredAdapter.LoadFilteredPolicy(e.model, filter); err != nil { + return err + } + + e.model.PrintPolicy() + if e.autoBuildRoleLinks { + e.BuildRoleLinks() + } + return nil +} + +// IsFiltered returns true if the loaded policy has been filtered. +func (e *Enforcer) IsFiltered() bool { + filteredAdapter, ok := e.adapter.(persist.FilteredAdapter) + if !ok { + return false + } + return filteredAdapter.IsFiltered() +} + +// SavePolicy saves the current policy (usually after changed with Casbin API) back to file/database. +func (e *Enforcer) SavePolicy() error { + if e.IsFiltered() { + return errors.New("cannot save a filtered policy") + } + if err := e.adapter.SavePolicy(e.model); err != nil { + return err + } + if e.watcher != nil { + return e.watcher.Update() + } + return nil +} + +// EnableEnforce changes the enforcing state of Casbin, when Casbin is disabled, all access will be allowed by the Enforce() function. +func (e *Enforcer) EnableEnforce(enable bool) { + e.enabled = enable +} + +// EnableLog changes whether to print Casbin log to the standard output. +func (e *Enforcer) EnableLog(enable bool) { + util.EnableLog = enable +} + +// EnableAutoSave controls whether to save a policy rule automatically to the adapter when it is added or removed. 
+func (e *Enforcer) EnableAutoSave(autoSave bool) { + e.autoSave = autoSave +} + +// EnableAutoBuildRoleLinks controls whether to rebuild the role inheritance relations when a role is added or deleted. +func (e *Enforcer) EnableAutoBuildRoleLinks(autoBuildRoleLinks bool) { + e.autoBuildRoleLinks = autoBuildRoleLinks +} + +// BuildRoleLinks manually rebuild the role inheritance relations. +func (e *Enforcer) BuildRoleLinks() { + // error intentionally ignored + e.rm.Clear() + e.model.BuildRoleLinks(e.rm) +} + +// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act). +func (e *Enforcer) Enforce(rvals ...interface{}) bool { + if !e.enabled { + return true + } + + functions := make(map[string]govaluate.ExpressionFunction) + for key, function := range e.fm { + functions[key] = function + } + if _, ok := e.model["g"]; ok { + for key, ast := range e.model["g"] { + rm := ast.RM + functions[key] = util.GenerateGFunction(rm) + } + } + + expString := e.model["m"]["m"].Value + expression, _ := govaluate.NewEvaluableExpressionWithFunctions(expString, functions) + + var policyEffects []effect.Effect + var matcherResults []float64 + if policyLen := len(e.model["p"]["p"].Policy); policyLen != 0 { + policyEffects = make([]effect.Effect, policyLen) + matcherResults = make([]float64, policyLen) + + for i, pvals := range e.model["p"]["p"].Policy { + // util.LogPrint("Policy Rule: ", pvals) + + parameters := make(map[string]interface{}, 8) + for j, token := range e.model["r"]["r"].Tokens { + parameters[token] = rvals[j] + } + for j, token := range e.model["p"]["p"].Tokens { + parameters[token] = pvals[j] + } + + result, err := expression.Evaluate(parameters) + // util.LogPrint("Result: ", result) + + if err != nil { + policyEffects[i] = effect.Indeterminate + panic(err) + } + + switch result.(type) { + case bool: + if !result.(bool) { + policyEffects[i] = effect.Indeterminate + continue + } + case float64: + if result.(float64) == 0 { + policyEffects[i] = effect.Indeterminate + continue + } else { + matcherResults[i] = result.(float64) + } + default: + panic(errors.New("matcher result should be bool, int or float")) + } + + if eft, ok := parameters["p_eft"]; ok { + if eft == "allow" { + policyEffects[i] = effect.Allow + } else if eft == "deny" { + policyEffects[i] = effect.Deny + } else { + policyEffects[i] = effect.Indeterminate + } + } else { + policyEffects[i] = effect.Allow + } + + if e.model["e"]["e"].Value == "priority(p_eft) || deny" { + break + } + + } + } else { + policyEffects = make([]effect.Effect, 1) + matcherResults = make([]float64, 1) + + parameters := make(map[string]interface{}, 8) + for j, token := range e.model["r"]["r"].Tokens { + parameters[token] = rvals[j] + } + for _, token := range e.model["p"]["p"].Tokens { + parameters[token] = "" + } + + result, err := expression.Evaluate(parameters) + // util.LogPrint("Result: ", result) + + if err != nil { + policyEffects[0] = effect.Indeterminate + panic(err) + } + + if result.(bool) { + policyEffects[0] = effect.Allow + } else { + policyEffects[0] = effect.Indeterminate + } + } + + // util.LogPrint("Rule Results: ", policyEffects) + + result, err := e.eft.MergeEffects(e.model["e"]["e"].Value, policyEffects, matcherResults) + if err != nil { + panic(err) + } + + // only generate the request --> result string if the message + // is going to be logged. 
+ if util.EnableLog { + reqStr := "Request: " + for i, rval := range rvals { + if i != len(rvals)-1 { + reqStr += fmt.Sprintf("%v, ", rval) + } else { + reqStr += fmt.Sprintf("%v", rval) + } + } + reqStr += fmt.Sprintf(" ---> %t", result) + util.LogPrint(reqStr) + } + + return result +} diff --git a/vendor/github.com/casbin/casbin/enforcer_cached.go b/vendor/github.com/casbin/casbin/enforcer_cached.go new file mode 100644 index 00000000..fffc8340 --- /dev/null +++ b/vendor/github.com/casbin/casbin/enforcer_cached.go @@ -0,0 +1,74 @@ +// Copyright 2018 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "sync" +) + +// CachedEnforcer wraps Enforcer and provides decision cache +type CachedEnforcer struct { + *Enforcer + m map[string]bool + enableCache bool + locker *sync.Mutex +} + +// NewCachedEnforcer creates a cached enforcer via file or DB. +func NewCachedEnforcer(params ...interface{}) *CachedEnforcer { + e := &CachedEnforcer{} + e.Enforcer = NewEnforcer(params...) + e.enableCache = true + e.m = make(map[string]bool) + e.locker = new(sync.Mutex) + return e +} + +// EnableCache determines whether to enable cache on Enforce(). When enableCache is enabled, cached result (true | false) will be returned for previous decisions. +func (e *CachedEnforcer) EnableCache(enableCache bool) { + e.enableCache = enableCache +} + +// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act). +// if rvals is not string , ingore the cache +func (e *CachedEnforcer) Enforce(rvals ...interface{}) bool { + if !e.enableCache { + return e.Enforcer.Enforce(rvals...) + } + + key := "" + for _, rval := range rvals { + if val, ok := rval.(string); ok { + key += val + "$$" + } else { + return e.Enforcer.Enforce(rvals...) + } + } + + e.locker.Lock() + defer e.locker.Unlock() + if _, ok := e.m[key]; ok { + return e.m[key] + } else { + res := e.Enforcer.Enforce(rvals...) + e.m[key] = res + return res + } +} + +// InvalidateCache deletes all the existing cached decisions. +func (e *CachedEnforcer) InvalidateCache() { + e.m = make(map[string]bool) +} diff --git a/vendor/github.com/casbin/casbin/enforcer_safe.go b/vendor/github.com/casbin/casbin/enforcer_safe.go new file mode 100644 index 00000000..1cb358fe --- /dev/null +++ b/vendor/github.com/casbin/casbin/enforcer_safe.go @@ -0,0 +1,102 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "fmt" +) + +// NewEnforcerSafe calls NewEnforcer in a safe way, returns error instead of causing panic. +func NewEnforcerSafe(params ...interface{}) (e *Enforcer, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + e = nil + } + }() + + e = NewEnforcer(params...) + err = nil + return +} + +// LoadModelSafe calls LoadModel in a safe way, returns error instead of causing panic. +func (e *Enforcer) LoadModelSafe() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + e.LoadModel() + err = nil + return +} + +// EnforceSafe calls Enforce in a safe way, returns error instead of causing panic. +func (e *Enforcer) EnforceSafe(rvals ...interface{}) (result bool, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + result = false + } + }() + + result = e.Enforce(rvals...) + err = nil + return +} + +// AddPolicySafe calls AddPolicy in a safe way, returns error instead of causing panic. +func (e *Enforcer) AddPolicySafe(params ...interface{}) (result bool, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + result = false + } + }() + + result = e.AddNamedPolicy("p", params...) + err = nil + return +} + +// RemovePolicySafe calls RemovePolicy in a safe way, returns error instead of causing panic. +func (e *Enforcer) RemovePolicySafe(params ...interface{}) (result bool, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + result = false + } + }() + + result = e.RemoveNamedPolicy("p", params...) + err = nil + return +} + +// RemoveFilteredPolicySafe calls RemoveFilteredPolicy in a safe way, returns error instead of causing panic. +func (e *Enforcer) RemoveFilteredPolicySafe(fieldIndex int, fieldValues ...string) (result bool, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + result = false + } + }() + + result = e.RemoveFilteredNamedPolicy("p", fieldIndex, fieldValues...) + err = nil + return +} diff --git a/vendor/github.com/casbin/casbin/enforcer_synced.go b/vendor/github.com/casbin/casbin/enforcer_synced.go new file mode 100644 index 00000000..f752c71f --- /dev/null +++ b/vendor/github.com/casbin/casbin/enforcer_synced.go @@ -0,0 +1,223 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "log" + "sync" + "time" + + "github.com/casbin/casbin/persist" +) + +// SyncedEnforcer wraps Enforcer and provides synchronized access +type SyncedEnforcer struct { + *Enforcer + m sync.RWMutex + autoLoad bool +} + +// NewSyncedEnforcer creates a synchronized enforcer via file or DB. +func NewSyncedEnforcer(params ...interface{}) *SyncedEnforcer { + e := &SyncedEnforcer{} + e.Enforcer = NewEnforcer(params...) 
+ e.autoLoad = false + return e +} + +// StartAutoLoadPolicy starts a go routine that will every specified duration call LoadPolicy +func (e *SyncedEnforcer) StartAutoLoadPolicy(d time.Duration) { + e.autoLoad = true + go func() { + n := 1 + log.Print("Start automatically load policy") + for { + if !e.autoLoad { + log.Print("Stop automatically load policy") + break + } + + // error intentionally ignored + e.LoadPolicy() + // Uncomment this line to see when the policy is loaded. + // log.Print("Load policy for time: ", n) + n++ + time.Sleep(d) + } + }() +} + +// StopAutoLoadPolicy causes the go routine to exit. +func (e *SyncedEnforcer) StopAutoLoadPolicy() { + e.autoLoad = false +} + +// SetWatcher sets the current watcher. +func (e *SyncedEnforcer) SetWatcher(watcher persist.Watcher) { + e.watcher = watcher + // error intentionally ignored + watcher.SetUpdateCallback(func(string) { e.LoadPolicy() }) +} + +// ClearPolicy clears all policy. +func (e *SyncedEnforcer) ClearPolicy() { + e.m.Lock() + defer e.m.Unlock() + e.Enforcer.ClearPolicy() +} + +// LoadPolicy reloads the policy from file/database. +func (e *SyncedEnforcer) LoadPolicy() error { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.LoadPolicy() +} + +// SavePolicy saves the current policy (usually after changed with Casbin API) back to file/database. +func (e *SyncedEnforcer) SavePolicy() error { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.SavePolicy() +} + +// BuildRoleLinks manually rebuild the role inheritance relations. +func (e *SyncedEnforcer) BuildRoleLinks() { + e.m.RLock() + defer e.m.RUnlock() + e.Enforcer.BuildRoleLinks() +} + +// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act). +func (e *SyncedEnforcer) Enforce(rvals ...interface{}) bool { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.Enforce(rvals...) +} + +// GetAllSubjects gets the list of subjects that show up in the current policy. +func (e *SyncedEnforcer) GetAllSubjects() []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetAllSubjects() +} + +// GetAllObjects gets the list of objects that show up in the current policy. +func (e *SyncedEnforcer) GetAllObjects() []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetAllObjects() +} + +// GetAllActions gets the list of actions that show up in the current policy. +func (e *SyncedEnforcer) GetAllActions() []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetAllActions() +} + +// GetAllRoles gets the list of roles that show up in the current policy. +func (e *SyncedEnforcer) GetAllRoles() []string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetAllRoles() +} + +// GetPolicy gets all the authorization rules in the policy. +func (e *SyncedEnforcer) GetPolicy() [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetPolicy() +} + +// GetFilteredPolicy gets all the authorization rules in the policy, field filters can be specified. +func (e *SyncedEnforcer) GetFilteredPolicy(fieldIndex int, fieldValues ...string) [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetFilteredPolicy(fieldIndex, fieldValues...) +} + +// GetGroupingPolicy gets all the role inheritance rules in the policy. 
+func (e *SyncedEnforcer) GetGroupingPolicy() [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetGroupingPolicy() +} + +// GetFilteredGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified. +func (e *SyncedEnforcer) GetFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) [][]string { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.GetFilteredGroupingPolicy(fieldIndex, fieldValues...) +} + +// HasPolicy determines whether an authorization rule exists. +func (e *SyncedEnforcer) HasPolicy(params ...interface{}) bool { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.HasPolicy(params...) +} + +// AddPolicy adds an authorization rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *SyncedEnforcer) AddPolicy(params ...interface{}) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddPolicy(params...) +} + +// RemovePolicy removes an authorization rule from the current policy. +func (e *SyncedEnforcer) RemovePolicy(params ...interface{}) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.RemovePolicy(params...) +} + +// RemoveFilteredPolicy removes an authorization rule from the current policy, field filters can be specified. +func (e *SyncedEnforcer) RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.RemoveFilteredPolicy(fieldIndex, fieldValues...) +} + +// HasGroupingPolicy determines whether a role inheritance rule exists. +func (e *SyncedEnforcer) HasGroupingPolicy(params ...interface{}) bool { + e.m.RLock() + defer e.m.RUnlock() + return e.Enforcer.HasGroupingPolicy(params...) +} + +// AddGroupingPolicy adds a role inheritance rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *SyncedEnforcer) AddGroupingPolicy(params ...interface{}) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddGroupingPolicy(params...) +} + +// RemoveGroupingPolicy removes a role inheritance rule from the current policy. +func (e *SyncedEnforcer) RemoveGroupingPolicy(params ...interface{}) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.RemoveGroupingPolicy(params...) +} + +// RemoveFilteredGroupingPolicy removes a role inheritance rule from the current policy, field filters can be specified. +func (e *SyncedEnforcer) RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.RemoveFilteredGroupingPolicy(fieldIndex, fieldValues...) +} diff --git a/vendor/github.com/casbin/casbin/internal_api.go b/vendor/github.com/casbin/casbin/internal_api.go new file mode 100644 index 00000000..6e6acead --- /dev/null +++ b/vendor/github.com/casbin/casbin/internal_api.go @@ -0,0 +1,85 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +const ( + notImplemented = "not implemented" +) + +// addPolicy adds a rule to the current policy. +func (e *Enforcer) addPolicy(sec string, ptype string, rule []string) bool { + ruleAdded := e.model.AddPolicy(sec, ptype, rule) + if !ruleAdded { + return ruleAdded + } + + if e.adapter != nil && e.autoSave { + if err := e.adapter.AddPolicy(sec, ptype, rule); err != nil { + if err.Error() != notImplemented { + panic(err) + } + } + if e.watcher != nil { + // error intentionally ignored + e.watcher.Update() + } + } + + return ruleAdded +} + +// removePolicy removes a rule from the current policy. +func (e *Enforcer) removePolicy(sec string, ptype string, rule []string) bool { + ruleRemoved := e.model.RemovePolicy(sec, ptype, rule) + if !ruleRemoved { + return ruleRemoved + } + + if e.adapter != nil && e.autoSave { + if err := e.adapter.RemovePolicy(sec, ptype, rule); err != nil { + if err.Error() != notImplemented { + panic(err) + } + } + if e.watcher != nil { + // error intentionally ignored + e.watcher.Update() + } + } + + return ruleRemoved +} + +// removeFilteredPolicy removes rules based on field filters from the current policy. +func (e *Enforcer) removeFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) bool { + ruleRemoved := e.model.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...) + if !ruleRemoved { + return ruleRemoved + } + + if e.adapter != nil && e.autoSave { + if err := e.adapter.RemoveFilteredPolicy(sec, ptype, fieldIndex, fieldValues...); err != nil { + if err.Error() != notImplemented { + panic(err) + } + } + if e.watcher != nil { + // error intentionally ignored + e.watcher.Update() + } + } + + return ruleRemoved +} diff --git a/vendor/github.com/casbin/casbin/management_api.go b/vendor/github.com/casbin/casbin/management_api.go new file mode 100644 index 00000000..355e17a7 --- /dev/null +++ b/vendor/github.com/casbin/casbin/management_api.go @@ -0,0 +1,265 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetAllSubjects gets the list of subjects that show up in the current policy. +func (e *Enforcer) GetAllSubjects() []string { + return e.GetAllNamedSubjects("p") +} + +// GetAllNamedSubjects gets the list of subjects that show up in the current named policy. +func (e *Enforcer) GetAllNamedSubjects(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 0) +} + +// GetAllObjects gets the list of objects that show up in the current policy. 
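The auto-save flow above panics on adapter errors unless the error text is exactly "not implemented", which is how read-only adapters opt out of incremental persistence. A hypothetical sketch of such an adapter (the package and type names are illustrative):

```go
package roadapter

import (
	"errors"

	"github.com/casbin/casbin/model"
	"github.com/casbin/casbin/persist"
)

// ReadOnlyAdapter loads rules from somewhere but never writes them back.
type ReadOnlyAdapter struct{}

// compile-time check that the sketch satisfies persist.Adapter
var _ persist.Adapter = (*ReadOnlyAdapter)(nil)

func (a *ReadOnlyAdapter) LoadPolicy(m model.Model) error {
	// fetch and attach rules here, e.g. via persist.LoadPolicyLine
	return nil
}

func (a *ReadOnlyAdapter) SavePolicy(m model.Model) error {
	return errors.New("not implemented")
}

// Returning exactly "not implemented" makes addPolicy/removePolicy
// above skip persistence silently instead of panicking.
func (a *ReadOnlyAdapter) AddPolicy(sec string, ptype string, rule []string) error {
	return errors.New("not implemented")
}

func (a *ReadOnlyAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
	return errors.New("not implemented")
}

func (a *ReadOnlyAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
	return errors.New("not implemented")
}
```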
+func (e *Enforcer) GetAllObjects() []string { + return e.GetAllNamedObjects("p") +} + +// GetAllNamedObjects gets the list of objects that show up in the current named policy. +func (e *Enforcer) GetAllNamedObjects(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 1) +} + +// GetAllActions gets the list of actions that show up in the current policy. +func (e *Enforcer) GetAllActions() []string { + return e.GetAllNamedActions("p") +} + +// GetAllNamedActions gets the list of actions that show up in the current named policy. +func (e *Enforcer) GetAllNamedActions(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("p", ptype, 2) +} + +// GetAllRoles gets the list of roles that show up in the current policy. +func (e *Enforcer) GetAllRoles() []string { + return e.GetAllNamedRoles("g") +} + +// GetAllNamedRoles gets the list of roles that show up in the current named policy. +func (e *Enforcer) GetAllNamedRoles(ptype string) []string { + return e.model.GetValuesForFieldInPolicy("g", ptype, 1) +} + +// GetPolicy gets all the authorization rules in the policy. +func (e *Enforcer) GetPolicy() [][]string { + return e.GetNamedPolicy("p") +} + +// GetFilteredPolicy gets all the authorization rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredPolicy(fieldIndex int, fieldValues ...string) [][]string { + return e.GetFilteredNamedPolicy("p", fieldIndex, fieldValues...) +} + +// GetNamedPolicy gets all the authorization rules in the named policy. +func (e *Enforcer) GetNamedPolicy(ptype string) [][]string { + return e.model.GetPolicy("p", ptype) +} + +// GetFilteredNamedPolicy gets all the authorization rules in the named policy, field filters can be specified. +func (e *Enforcer) GetFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string { + return e.model.GetFilteredPolicy("p", ptype, fieldIndex, fieldValues...) +} + +// GetGroupingPolicy gets all the role inheritance rules in the policy. +func (e *Enforcer) GetGroupingPolicy() [][]string { + return e.GetNamedGroupingPolicy("g") +} + +// GetFilteredGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) [][]string { + return e.GetFilteredNamedGroupingPolicy("g", fieldIndex, fieldValues...) +} + +// GetNamedGroupingPolicy gets all the role inheritance rules in the policy. +func (e *Enforcer) GetNamedGroupingPolicy(ptype string) [][]string { + return e.model.GetPolicy("g", ptype) +} + +// GetFilteredNamedGroupingPolicy gets all the role inheritance rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) [][]string { + return e.model.GetFilteredPolicy("g", ptype, fieldIndex, fieldValues...) +} + +// HasPolicy determines whether an authorization rule exists. +func (e *Enforcer) HasPolicy(params ...interface{}) bool { + return e.HasNamedPolicy("p", params...) +} + +// HasNamedPolicy determines whether a named authorization rule exists. 
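A short sketch of the read-side management API above. The file paths are placeholders, and NewEnforcerSafe comes from the safe wrappers earlier in this patch:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e, err := casbin.NewEnforcerSafe("rbac_model.conf", "rbac_policy.csv")
	if err != nil {
		panic(err)
	}

	fmt.Println(e.GetAllSubjects())              // distinct p-rule subjects
	fmt.Println(e.GetPolicy())                   // every p rule as [][]string
	fmt.Println(e.GetFilteredPolicy(0, "alice")) // p rules whose sub == "alice"
	fmt.Println(e.HasPolicy("alice", "data1", "read"))
}
```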
+func (e *Enforcer) HasNamedPolicy(ptype string, params ...interface{}) bool { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + return e.model.HasPolicy("p", ptype, strSlice) + } + + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.model.HasPolicy("p", ptype, policy) +} + +// AddPolicy adds an authorization rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddPolicy(params ...interface{}) bool { + return e.AddNamedPolicy("p", params...) +} + +// AddNamedPolicy adds an authorization rule to the current named policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddNamedPolicy(ptype string, params ...interface{}) bool { + var ruleAdded bool + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleAdded = e.addPolicy("p", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleAdded = e.addPolicy("p", ptype, policy) + } + + return ruleAdded +} + +// RemovePolicy removes an authorization rule from the current policy. +func (e *Enforcer) RemovePolicy(params ...interface{}) bool { + return e.RemoveNamedPolicy("p", params...) +} + +// RemoveFilteredPolicy removes an authorization rule from the current policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredPolicy(fieldIndex int, fieldValues ...string) bool { + return e.RemoveFilteredNamedPolicy("p", fieldIndex, fieldValues...) +} + +// RemoveNamedPolicy removes an authorization rule from the current named policy. +func (e *Enforcer) RemoveNamedPolicy(ptype string, params ...interface{}) bool { + var ruleRemoved bool + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleRemoved = e.removePolicy("p", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleRemoved = e.removePolicy("p", ptype, policy) + } + + return ruleRemoved +} + +// RemoveFilteredNamedPolicy removes an authorization rule from the current named policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredNamedPolicy(ptype string, fieldIndex int, fieldValues ...string) bool { + return e.removeFilteredPolicy("p", ptype, fieldIndex, fieldValues...) +} + +// HasGroupingPolicy determines whether a role inheritance rule exists. +func (e *Enforcer) HasGroupingPolicy(params ...interface{}) bool { + return e.HasNamedGroupingPolicy("g", params...) +} + +// HasNamedGroupingPolicy determines whether a named role inheritance rule exists. +func (e *Enforcer) HasNamedGroupingPolicy(ptype string, params ...interface{}) bool { + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + return e.model.HasPolicy("g", ptype, strSlice) + } + + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + return e.model.HasPolicy("g", ptype, policy) +} + +// AddGroupingPolicy adds a role inheritance rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. 
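As the parameter handling in this file shows, each of these calls accepts either variadic strings or a single []string; a sketch (enforcer construction assumed as in the earlier examples):

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewEnforcer("rbac_model.conf", "rbac_policy.csv")

	// Variadic form: each argument is one field of the rule.
	fmt.Println(e.AddPolicy("alice", "data1", "read")) // true if the rule was new

	// Slice form: a single []string is taken as the whole rule,
	// so this is the same rule and reports false (duplicate).
	fmt.Println(e.AddPolicy([]string{"alice", "data1", "read"}))
}
```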
+func (e *Enforcer) AddGroupingPolicy(params ...interface{}) bool { + return e.AddNamedGroupingPolicy("g", params...) +} + +// AddNamedGroupingPolicy adds a named role inheritance rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddNamedGroupingPolicy(ptype string, params ...interface{}) bool { + var ruleAdded bool + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleAdded = e.addPolicy("g", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleAdded = e.addPolicy("g", ptype, policy) + } + + if e.autoBuildRoleLinks { + e.BuildRoleLinks() + } + return ruleAdded +} + +// RemoveGroupingPolicy removes a role inheritance rule from the current policy. +func (e *Enforcer) RemoveGroupingPolicy(params ...interface{}) bool { + return e.RemoveNamedGroupingPolicy("g", params...) +} + +// RemoveFilteredGroupingPolicy removes a role inheritance rule from the current policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredGroupingPolicy(fieldIndex int, fieldValues ...string) bool { + return e.RemoveFilteredNamedGroupingPolicy("g", fieldIndex, fieldValues...) +} + +// RemoveNamedGroupingPolicy removes a role inheritance rule from the current named policy. +func (e *Enforcer) RemoveNamedGroupingPolicy(ptype string, params ...interface{}) bool { + var ruleRemoved bool + if strSlice, ok := params[0].([]string); len(params) == 1 && ok { + ruleRemoved = e.removePolicy("g", ptype, strSlice) + } else { + policy := make([]string, 0) + for _, param := range params { + policy = append(policy, param.(string)) + } + + ruleRemoved = e.removePolicy("g", ptype, policy) + } + + if e.autoBuildRoleLinks { + e.BuildRoleLinks() + } + return ruleRemoved +} + +// RemoveFilteredNamedGroupingPolicy removes a role inheritance rule from the current named policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredNamedGroupingPolicy(ptype string, fieldIndex int, fieldValues ...string) bool { + ruleRemoved := e.removeFilteredPolicy("g", ptype, fieldIndex, fieldValues...) + + if e.autoBuildRoleLinks { + e.BuildRoleLinks() + } + return ruleRemoved +} + +// AddFunction adds a customized function. +func (e *Enforcer) AddFunction(name string, function func(args ...interface{}) (interface{}, error)) { + e.fm.AddFunction(name, function) +} diff --git a/vendor/github.com/casbin/casbin/model/assertion.go b/vendor/github.com/casbin/casbin/model/assertion.go new file mode 100644 index 00000000..ee46fd75 --- /dev/null +++ b/vendor/github.com/casbin/casbin/model/assertion.go @@ -0,0 +1,60 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
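AddFunction at the end of the file above registers a custom function for use in matchers. A hedged sketch, assuming a model whose matcher references the function by name (e.g. m = myMatch(r.obj, p.obj) && r.act == p.act); the file names and matching rule are illustrative only:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/casbin/casbin"
)

func main() {
	// "custom_model.conf"/"custom_policy.csv" are placeholders.
	e := casbin.NewEnforcer("custom_model.conf", "custom_policy.csv")

	// Signature matches the vendored API:
	// func(args ...interface{}) (interface{}, error)
	e.AddFunction("myMatch", func(args ...interface{}) (interface{}, error) {
		obj := args[0].(string)
		pattern := args[1].(string)
		// toy rule: a trailing "*" in the pattern matches any suffix
		if strings.HasSuffix(pattern, "*") {
			return strings.HasPrefix(obj, strings.TrimSuffix(pattern, "*")), nil
		}
		return obj == pattern, nil
	})

	fmt.Println(e.Enforce("alice", "/data1/report", "read"))
}
```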
+ +package model + +import ( + "errors" + "strings" + + "github.com/casbin/casbin/rbac" + "github.com/casbin/casbin/util" +) + +// Assertion represents an expression in a section of the model. +// For example: r = sub, obj, act +type Assertion struct { + Key string + Value string + Tokens []string + Policy [][]string + RM rbac.RoleManager +} + +func (ast *Assertion) buildRoleLinks(rm rbac.RoleManager) { + ast.RM = rm + count := strings.Count(ast.Value, "_") + for _, rule := range ast.Policy { + if count < 2 { + panic(errors.New("the number of \"_\" in role definition should be at least 2")) + } + if len(rule) < count { + panic(errors.New("grouping policy elements do not meet role definition")) + } + + if count == 2 { + // error intentionally ignored + ast.RM.AddLink(rule[0], rule[1]) + } else if count == 3 { + // error intentionally ignored + ast.RM.AddLink(rule[0], rule[1], rule[2]) + } else if count == 4 { + // error intentionally ignored + ast.RM.AddLink(rule[0], rule[1], rule[2], rule[3]) + } + } + + util.LogPrint("Role links for: " + ast.Key) + ast.RM.PrintRoles() +} diff --git a/vendor/github.com/casbin/casbin/model/function.go b/vendor/github.com/casbin/casbin/model/function.go new file mode 100644 index 00000000..5c81d180 --- /dev/null +++ b/vendor/github.com/casbin/casbin/model/function.go @@ -0,0 +1,40 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import "github.com/casbin/casbin/util" + +// FunctionMap represents the collection of Function. +type FunctionMap map[string]func(args ...interface{}) (interface{}, error) + +// Function represents a function that is used in the matchers, used to get attributes in ABAC. +type Function func(args ...interface{}) (interface{}, error) + +// AddFunction adds an expression function. +func (fm FunctionMap) AddFunction(name string, function Function) { + fm[name] = function +} + +// LoadFunctionMap loads an initial function map. +func LoadFunctionMap() FunctionMap { + fm := make(FunctionMap) + + fm.AddFunction("keyMatch", util.KeyMatchFunc) + fm.AddFunction("keyMatch2", util.KeyMatch2Func) + fm.AddFunction("regexMatch", util.RegexMatchFunc) + fm.AddFunction("ipMatch", util.IPMatchFunc) + + return fm +} diff --git a/vendor/github.com/casbin/casbin/model/model.go b/vendor/github.com/casbin/casbin/model/model.go new file mode 100644 index 00000000..55a1c120 --- /dev/null +++ b/vendor/github.com/casbin/casbin/model/model.go @@ -0,0 +1,129 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +import ( + "strconv" + "strings" + + "github.com/casbin/casbin/config" + "github.com/casbin/casbin/util" +) + +// Model represents the whole access control model. +type Model map[string]AssertionMap + +// AssertionMap is the collection of assertions, can be "r", "p", "g", "e", "m". +type AssertionMap map[string]*Assertion + +var sectionNameMap = map[string]string{ + "r": "request_definition", + "p": "policy_definition", + "g": "role_definition", + "e": "policy_effect", + "m": "matchers", +} + +func loadAssertion(model Model, cfg config.ConfigInterface, sec string, key string) bool { + value := cfg.String(sectionNameMap[sec] + "::" + key) + return model.AddDef(sec, key, value) +} + +// AddDef adds an assertion to the model. +func (model Model) AddDef(sec string, key string, value string) bool { + ast := Assertion{} + ast.Key = key + ast.Value = value + + if ast.Value == "" { + return false + } + + if sec == "r" || sec == "p" { + ast.Tokens = strings.Split(ast.Value, ", ") + for i := range ast.Tokens { + ast.Tokens[i] = key + "_" + ast.Tokens[i] + } + } else { + ast.Value = util.RemoveComments(util.EscapeAssertion(ast.Value)) + } + + _, ok := model[sec] + if !ok { + model[sec] = make(AssertionMap) + } + + model[sec][key] = &ast + return true +} + +func getKeySuffix(i int) string { + if i == 1 { + return "" + } + + return strconv.Itoa(i) +} + +func loadSection(model Model, cfg config.ConfigInterface, sec string) { + i := 1 + for { + if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) { + break + } else { + i++ + } + } +} + +// LoadModel loads the model from model CONF file. +func (model Model) LoadModel(path string) { + cfg, err := config.NewConfig(path) + if err != nil { + panic(err) + } + + loadSection(model, cfg, "r") + loadSection(model, cfg, "p") + loadSection(model, cfg, "e") + loadSection(model, cfg, "m") + + loadSection(model, cfg, "g") +} + +// LoadModelFromText loads the model from the text. +func (model Model) LoadModelFromText(text string) { + cfg, err := config.NewConfigFromText(text) + if err != nil { + panic(err) + } + + loadSection(model, cfg, "r") + loadSection(model, cfg, "p") + loadSection(model, cfg, "e") + loadSection(model, cfg, "m") + + loadSection(model, cfg, "g") +} + +// PrintModel prints the model to the log. +func (model Model) PrintModel() { + util.LogPrint("Model:") + for k, v := range model { + for i, j := range v { + util.LogPrintf("%s.%s: %s", k, i, j.Value) + } + } +} diff --git a/vendor/github.com/casbin/casbin/model/policy.go b/vendor/github.com/casbin/casbin/model/policy.go new file mode 100644 index 00000000..9d3808ed --- /dev/null +++ b/vendor/github.com/casbin/casbin/model/policy.go @@ -0,0 +1,146 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
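Since Model is a plain map type, a model can be built directly from text with LoadModelFromText above; a minimal sketch using the standard ACL layout:

```go
package main

import "github.com/casbin/casbin/model"

func main() {
	text := `
[request_definition]
r = sub, obj, act

[policy_definition]
p = sub, obj, act

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = r.sub == p.sub && r.obj == p.obj && r.act == p.act
`
	m := make(model.Model) // Model is a map, so make() initializes it
	m.LoadModelFromText(text)
	m.PrintModel() // logs the r/p/e/m definitions
}
```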
+ +package model + +import ( + "github.com/casbin/casbin/rbac" + "github.com/casbin/casbin/util" +) + +// BuildRoleLinks initializes the roles in RBAC. +func (model Model) BuildRoleLinks(rm rbac.RoleManager) { + for _, ast := range model["g"] { + ast.buildRoleLinks(rm) + } +} + +// PrintPolicy prints the policy to log. +func (model Model) PrintPolicy() { + util.LogPrint("Policy:") + for key, ast := range model["p"] { + util.LogPrint(key, ": ", ast.Value, ": ", ast.Policy) + } + + for key, ast := range model["g"] { + util.LogPrint(key, ": ", ast.Value, ": ", ast.Policy) + } +} + +// ClearPolicy clears all current policy. +func (model Model) ClearPolicy() { + for _, ast := range model["p"] { + ast.Policy = nil + } + + for _, ast := range model["g"] { + ast.Policy = nil + } +} + +// GetPolicy gets all rules in a policy. +func (model Model) GetPolicy(sec string, ptype string) [][]string { + return model[sec][ptype].Policy +} + +// GetFilteredPolicy gets rules based on field filters from a policy. +func (model Model) GetFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) [][]string { + res := [][]string{} + + for _, rule := range model[sec][ptype].Policy { + matched := true + for i, fieldValue := range fieldValues { + if fieldValue != "" && rule[fieldIndex+i] != fieldValue { + matched = false + break + } + } + + if matched { + res = append(res, rule) + } + } + + return res +} + +// HasPolicy determines whether a model has the specified policy rule. +func (model Model) HasPolicy(sec string, ptype string, rule []string) bool { + for _, r := range model[sec][ptype].Policy { + if util.ArrayEquals(rule, r) { + return true + } + } + + return false +} + +// AddPolicy adds a policy rule to the model. +func (model Model) AddPolicy(sec string, ptype string, rule []string) bool { + if !model.HasPolicy(sec, ptype, rule) { + model[sec][ptype].Policy = append(model[sec][ptype].Policy, rule) + return true + } + return false +} + +// RemovePolicy removes a policy rule from the model. +func (model Model) RemovePolicy(sec string, ptype string, rule []string) bool { + for i, r := range model[sec][ptype].Policy { + if util.ArrayEquals(rule, r) { + model[sec][ptype].Policy = append(model[sec][ptype].Policy[:i], model[sec][ptype].Policy[i+1:]...) + return true + } + } + + return false +} + +// RemoveFilteredPolicy removes policy rules based on field filters from the model. +func (model Model) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) bool { + tmp := [][]string{} + res := false + for _, rule := range model[sec][ptype].Policy { + matched := true + for i, fieldValue := range fieldValues { + if fieldValue != "" && rule[fieldIndex+i] != fieldValue { + matched = false + break + } + } + + if matched { + res = true + } else { + tmp = append(tmp, rule) + } + } + + model[sec][ptype].Policy = tmp + return res +} + +// GetValuesForFieldInPolicy gets all values for a field for all rules in a policy, duplicated values are removed. 
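The field filters implemented above start at fieldIndex and treat an empty string as "match anything"; a sketch through the enforcer-level wrappers (construction as in the earlier examples):

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewEnforcer("rbac_model.conf", "rbac_policy.csv")

	// All p rules whose obj (field 1) is "data1", any sub/act.
	fmt.Println(e.GetFilteredPolicy(1, "data1"))

	// Starting at field 0 with an empty first value: any subject,
	// obj must be "data2". Removes every matching rule.
	fmt.Println(e.RemoveFilteredPolicy(0, "", "data2"))
}
```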
+func (model Model) GetValuesForFieldInPolicy(sec string, ptype string, fieldIndex int) []string { + values := []string{} + + for _, rule := range model[sec][ptype].Policy { + values = append(values, rule[fieldIndex]) + } + + util.ArrayRemoveDuplicates(&values) + // sort.Strings(values) + + return values +} diff --git a/vendor/github.com/casbin/casbin/persist/adapter.go b/vendor/github.com/casbin/casbin/persist/adapter.go new file mode 100644 index 00000000..0d994b98 --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/adapter.go @@ -0,0 +1,56 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "strings" + + "github.com/casbin/casbin/model" +) + +// LoadPolicyLine loads a text line as a policy rule to model. +func LoadPolicyLine(line string, model model.Model) { + if line == "" { + return + } + + if strings.HasPrefix(line, "#") { + return + } + + tokens := strings.Split(line, ", ") + + key := tokens[0] + sec := key[:1] + model[sec][key].Policy = append(model[sec][key].Policy, tokens[1:]) +} + +// Adapter is the interface for Casbin adapters. +type Adapter interface { + // LoadPolicy loads all policy rules from the storage. + LoadPolicy(model model.Model) error + // SavePolicy saves all policy rules to the storage. + SavePolicy(model model.Model) error + + // AddPolicy adds a policy rule to the storage. + // This is part of the Auto-Save feature. + AddPolicy(sec string, ptype string, rule []string) error + // RemovePolicy removes a policy rule from the storage. + // This is part of the Auto-Save feature. + RemovePolicy(sec string, ptype string, rule []string) error + // RemoveFilteredPolicy removes policy rules that match the filter from the storage. + // This is part of the Auto-Save feature. + RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error +} diff --git a/vendor/github.com/casbin/casbin/persist/adapter_filtered.go b/vendor/github.com/casbin/casbin/persist/adapter_filtered.go new file mode 100644 index 00000000..fab51e72 --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/adapter_filtered.go @@ -0,0 +1,29 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +import ( + "github.com/casbin/casbin/model" +) + +// FilteredAdapter is the interface for Casbin adapters supporting filtered policies. 
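LoadPolicyLine above is the parsing helper adapters share: the first CSV token selects the ptype, and its leading letter selects the section. A small sketch feeding it directly:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin/model"
	"github.com/casbin/casbin/persist"
)

func main() {
	m := make(model.Model)
	// Definitions must exist before policy lines can attach to them.
	m.AddDef("p", "p", "sub, obj, act")
	m.AddDef("g", "g", "_, _")

	persist.LoadPolicyLine("p, alice, data1, read", m) // -> m["p"]["p"]
	persist.LoadPolicyLine("g, alice, data2_admin", m) // -> m["g"]["g"]
	persist.LoadPolicyLine("# comments are skipped", m)
	persist.LoadPolicyLine("", m) // blank lines too

	fmt.Println(m.GetPolicy("p", "p")) // [[alice data1 read]]
}
```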
+type FilteredAdapter interface { + Adapter + + // LoadFilteredPolicy loads only policy rules that match the filter. + LoadFilteredPolicy(model model.Model, filter interface{}) error + // IsFiltered returns true if the loaded policy has been filtered. + IsFiltered() bool +} diff --git a/vendor/github.com/casbin/casbin/persist/file-adapter/adapter.go b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter.go new file mode 100644 index 00000000..91bf7a9a --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter.go @@ -0,0 +1,117 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileadapter + +import ( + "bufio" + "bytes" + "errors" + "os" + "strings" + + "github.com/casbin/casbin/model" + "github.com/casbin/casbin/persist" + "github.com/casbin/casbin/util" +) + +// Adapter is the file adapter for Casbin. +// It can load policy from file or save policy to file. +type Adapter struct { + filePath string +} + +// NewAdapter is the constructor for Adapter. +func NewAdapter(filePath string) *Adapter { + return &Adapter{filePath: filePath} +} + +// LoadPolicy loads all policy rules from the storage. +func (a *Adapter) LoadPolicy(model model.Model) error { + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + return a.loadPolicyFile(model, persist.LoadPolicyLine) +} + +// SavePolicy saves all policy rules to the storage. +func (a *Adapter) SavePolicy(model model.Model) error { + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + var tmp bytes.Buffer + + for ptype, ast := range model["p"] { + for _, rule := range ast.Policy { + tmp.WriteString(ptype + ", ") + tmp.WriteString(util.ArrayToString(rule)) + tmp.WriteString("\n") + } + } + + for ptype, ast := range model["g"] { + for _, rule := range ast.Policy { + tmp.WriteString(ptype + ", ") + tmp.WriteString(util.ArrayToString(rule)) + tmp.WriteString("\n") + } + } + + return a.savePolicyFile(strings.TrimRight(tmp.String(), "\n")) +} + +func (a *Adapter) loadPolicyFile(model model.Model, handler func(string, model.Model)) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + handler(line, model) + } + return scanner.Err() +} + +func (a *Adapter) savePolicyFile(text string) error { + f, err := os.Create(a.filePath) + if err != nil { + return err + } + w := bufio.NewWriter(f) + // error intentionally ignored + w.WriteString(text) + w.Flush() + f.Close() + return nil +} + +// AddPolicy adds a policy rule to the storage. +func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { + return errors.New("not implemented") +} + +// RemovePolicy removes a policy rule from the storage. 
+func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { + return errors.New("not implemented") +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return errors.New("not implemented") +} diff --git a/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_filtered.go b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_filtered.go new file mode 100644 index 00000000..f0b18c3b --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_filtered.go @@ -0,0 +1,137 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileadapter + +import ( + "bufio" + "errors" + "os" + "strings" + + "github.com/casbin/casbin/model" + "github.com/casbin/casbin/persist" +) + +// FilteredAdapter is the filtered file adapter for Casbin. It can load policy +// from file or save policy to file and supports loading of filtered policies. +type FilteredAdapter struct { + *Adapter + filtered bool +} + +// Filter defines the filtering rules for a FilteredAdapter's policy. Empty values +// are ignored, but all others must match the filter. +type Filter struct { + P []string + G []string +} + +// NewFilteredAdapter is the constructor for FilteredAdapter. +func NewFilteredAdapter(filePath string) *FilteredAdapter { + a := FilteredAdapter{} + a.Adapter = NewAdapter(filePath) + return &a +} + +// LoadPolicy loads all policy rules from the storage. +func (a *FilteredAdapter) LoadPolicy(model model.Model) error { + a.filtered = false + return a.Adapter.LoadPolicy(model) +} + +// LoadFilteredPolicy loads only policy rules that match the filter. +func (a *FilteredAdapter) LoadFilteredPolicy(model model.Model, filter interface{}) error { + if filter == nil { + return a.LoadPolicy(model) + } + if a.filePath == "" { + return errors.New("invalid file path, file path cannot be empty") + } + + filterValue, ok := filter.(*Filter) + if !ok { + return errors.New("invalid filter type") + } + err := a.loadFilteredPolicyFile(model, filterValue, persist.LoadPolicyLine) + if err == nil { + a.filtered = true + } + return err +} + +func (a *FilteredAdapter) loadFilteredPolicyFile(model model.Model, filter *Filter, handler func(string, model.Model)) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if filterLine(line, filter) { + continue + } + + handler(line, model) + } + return scanner.Err() +} + +// IsFiltered returns true if the loaded policy has been filtered. +func (a *FilteredAdapter) IsFiltered() bool { + return a.filtered +} + +// SavePolicy saves all policy rules to the storage. 
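A sketch of wiring the file adapter above explicitly; in this vendored version NewEnforcer also accepts an adapter in place of a policy-file path (file names are placeholders):

```go
package main

import (
	"github.com/casbin/casbin"
	fileadapter "github.com/casbin/casbin/persist/file-adapter"
)

func main() {
	a := fileadapter.NewAdapter("rbac_policy.csv")
	e := casbin.NewEnforcer("rbac_model.conf", a)

	// Kept in memory only: the file adapter's AddPolicy returns
	// "not implemented", so the auto-save path is skipped.
	e.AddPolicy("bob", "data2", "write")

	// Serializes all p/g rules and rewrites the whole file.
	if err := e.SavePolicy(); err != nil {
		panic(err)
	}
}
```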
+func (a *FilteredAdapter) SavePolicy(model model.Model) error { + if a.filtered { + return errors.New("cannot save a filtered policy") + } + return a.Adapter.SavePolicy(model) +} + +func filterLine(line string, filter *Filter) bool { + if filter == nil { + return false + } + p := strings.Split(line, ",") + if len(p) == 0 { + return true + } + var filterSlice []string + switch strings.TrimSpace(p[0]) { + case "p": + filterSlice = filter.P + case "g": + filterSlice = filter.G + } + return filterWords(p, filterSlice) +} + +func filterWords(line []string, filter []string) bool { + if len(line) < len(filter)+1 { + return true + } + var skipLine bool + for i, v := range filter { + if len(v) > 0 && strings.TrimSpace(v) != strings.TrimSpace(line[i+1]) { + skipLine = true + break + } + } + return skipLine +} diff --git a/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_mock.go b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_mock.go new file mode 100644 index 00000000..0d7f83a7 --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/file-adapter/adapter_mock.go @@ -0,0 +1,100 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileadapter + +import ( + "bufio" + "errors" + "io" + "os" + "strings" + + "github.com/casbin/casbin/model" + "github.com/casbin/casbin/persist" +) + +// AdapterMock is the file adapter for Casbin. +// It can load policy from file or save policy to file. +type AdapterMock struct { + filePath string + errorValue string +} + +// NewAdapterMock is the constructor for AdapterMock. +func NewAdapterMock(filePath string) *AdapterMock { + a := AdapterMock{} + a.filePath = filePath + return &a +} + +// LoadPolicy loads all policy rules from the storage. +func (a *AdapterMock) LoadPolicy(model model.Model) error { + err := a.loadPolicyFile(model, persist.LoadPolicyLine) + return err +} + +// SavePolicy saves all policy rules to the storage. +func (a *AdapterMock) SavePolicy(model model.Model) error { + return nil +} + +func (a *AdapterMock) loadPolicyFile(model model.Model, handler func(string, model.Model)) error { + f, err := os.Open(a.filePath) + if err != nil { + return err + } + defer f.Close() + + buf := bufio.NewReader(f) + for { + line, err := buf.ReadString('\n') + line = strings.TrimSpace(line) + handler(line, model) + if err != nil { + if err == io.EOF { + return nil + } + } + } +} + +// SetMockErr sets string to be returned by of the mock during testing +func (a *AdapterMock) SetMockErr(errorToSet string) { + a.errorValue = errorToSet +} + +// GetMockErr returns a mock error or nil +func (a *AdapterMock) GetMockErr() error { + var returnError error + if a.errorValue != "" { + returnError = errors.New(a.errorValue) + } + return returnError +} + +// AddPolicy adds a policy rule to the storage. +func (a *AdapterMock) AddPolicy(sec string, ptype string, rule []string) error { + return a.GetMockErr() +} + +// RemovePolicy removes a policy rule from the storage. 
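A sketch of filtered loading with the adapter above, assuming a domain-style policy file (the path and field layout are illustrative; empty filter fields match anything):

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin/model"
	fileadapter "github.com/casbin/casbin/persist/file-adapter"
)

func main() {
	a := fileadapter.NewFilteredAdapter("rbac_with_domains_policy.csv")

	m := make(model.Model)
	m.AddDef("p", "p", "sub, dom, obj, act")
	m.AddDef("g", "g", "_, _, _")

	// Keep only rules whose dom field is "domain1".
	filter := &fileadapter.Filter{
		P: []string{"", "domain1"},
		G: []string{"", "", "domain1"},
	}
	if err := a.LoadFilteredPolicy(m, filter); err != nil {
		panic(err)
	}

	// The guard above: a partial view must not overwrite the full file.
	fmt.Println(a.SavePolicy(m)) // "cannot save a filtered policy"
}
```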
+func (a *AdapterMock) RemovePolicy(sec string, ptype string, rule []string) error { + return a.GetMockErr() +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *AdapterMock) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return a.GetMockErr() +} diff --git a/vendor/github.com/casbin/casbin/persist/watcher.go b/vendor/github.com/casbin/casbin/persist/watcher.go new file mode 100644 index 00000000..089fbc0f --- /dev/null +++ b/vendor/github.com/casbin/casbin/persist/watcher.go @@ -0,0 +1,27 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persist + +// Watcher is the interface for Casbin watchers. +type Watcher interface { + // SetUpdateCallback sets the callback function that the watcher will call + // when the policy in DB has been changed by other instances. + // A classic callback is Enforcer.LoadPolicy(). + SetUpdateCallback(func(string)) error + // Update calls the update callback of other instances to synchronize their policy. + // It is usually called after changing the policy in DB, like Enforcer.SavePolicy(), + // Enforcer.AddPolicy(), Enforcer.RemovePolicy(), etc. + Update() error +} diff --git a/vendor/github.com/casbin/casbin/rbac/default-role-manager/role_manager.go b/vendor/github.com/casbin/casbin/rbac/default-role-manager/role_manager.go new file mode 100644 index 00000000..15893560 --- /dev/null +++ b/vendor/github.com/casbin/casbin/rbac/default-role-manager/role_manager.go @@ -0,0 +1,257 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package defaultrolemanager + +import ( + "errors" + "sync" + + "github.com/casbin/casbin/rbac" + "github.com/casbin/casbin/util" +) + +// RoleManager provides a default implementation for the RoleManager interface +type RoleManager struct { + allRoles *sync.Map + maxHierarchyLevel int +} + +// NewRoleManager is the constructor for creating an instance of the +// default RoleManager implementation. 
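The Watcher interface above is small enough to implement in-process; a minimal sketch of a watcher that could be handed to SyncedEnforcer.SetWatcher for single-process testing (the package and type names are illustrative):

```go
package localwatcher

import "github.com/casbin/casbin/persist"

// LocalWatcher satisfies persist.Watcher: Update simply invokes the
// registered callback synchronously.
type LocalWatcher struct {
	callback func(string)
}

// compile-time check against the vendored interface
var _ persist.Watcher = (*LocalWatcher)(nil)

// SetUpdateCallback stores the reload hook (typically Enforcer.LoadPolicy).
func (w *LocalWatcher) SetUpdateCallback(cb func(string)) error {
	w.callback = cb
	return nil
}

// Update simulates "another instance changed the policy in storage".
func (w *LocalWatcher) Update() error {
	if w.callback != nil {
		w.callback("")
	}
	return nil
}
```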
+func NewRoleManager(maxHierarchyLevel int) rbac.RoleManager { + rm := RoleManager{} + rm.allRoles = &sync.Map{} + rm.maxHierarchyLevel = maxHierarchyLevel + return &rm +} + +func (rm *RoleManager) hasRole(name string) bool { + _, ok := rm.allRoles.Load(name) + return ok +} + +func (rm *RoleManager) createRole(name string) *Role { + role, _ := rm.allRoles.LoadOrStore(name, newRole(name)) + return role.(*Role) +} + +// Clear clears all stored data and resets the role manager to the initial state. +func (rm *RoleManager) Clear() error { + rm.allRoles = &sync.Map{} + return nil +} + +// AddLink adds the inheritance link between role: name1 and role: name2. +// aka role: name1 inherits role: name2. +// domain is a prefix to the roles. +func (rm *RoleManager) AddLink(name1 string, name2 string, domain ...string) error { + if len(domain) == 1 { + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + } else if len(domain) > 1 { + return errors.New("error: domain should be 1 parameter") + } + + role1 := rm.createRole(name1) + role2 := rm.createRole(name2) + role1.addRole(role2) + return nil +} + +// DeleteLink deletes the inheritance link between role: name1 and role: name2. +// aka role: name1 does not inherit role: name2 any more. +// domain is a prefix to the roles. +func (rm *RoleManager) DeleteLink(name1 string, name2 string, domain ...string) error { + if len(domain) == 1 { + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + } else if len(domain) > 1 { + return errors.New("error: domain should be 1 parameter") + } + + if !rm.hasRole(name1) || !rm.hasRole(name2) { + return errors.New("error: name1 or name2 does not exist") + } + + role1 := rm.createRole(name1) + role2 := rm.createRole(name2) + role1.deleteRole(role2) + return nil +} + +// HasLink determines whether role: name1 inherits role: name2. +// domain is a prefix to the roles. +func (rm *RoleManager) HasLink(name1 string, name2 string, domain ...string) (bool, error) { + if len(domain) == 1 { + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + } else if len(domain) > 1 { + return false, errors.New("error: domain should be 1 parameter") + } + + if name1 == name2 { + return true, nil + } + + if !rm.hasRole(name1) || !rm.hasRole(name2) { + return false, nil + } + + role1 := rm.createRole(name1) + return role1.hasRole(name2, rm.maxHierarchyLevel), nil +} + +// GetRoles gets the roles that a subject inherits. +// domain is a prefix to the roles. +func (rm *RoleManager) GetRoles(name string, domain ...string) ([]string, error) { + if len(domain) == 1 { + name = domain[0] + "::" + name + } else if len(domain) > 1 { + return nil, errors.New("error: domain should be 1 parameter") + } + + if !rm.hasRole(name) { + return nil, errors.New("error: name does not exist") + } + + roles := rm.createRole(name).getRoles() + if len(domain) == 1 { + for i := range roles { + roles[i] = roles[i][len(domain[0])+2:] + } + } + return roles, nil +} + +// GetUsers gets the users that inherits a subject. +// domain is an unreferenced parameter here, may be used in other implementations. +func (rm *RoleManager) GetUsers(name string, domain ...string) ([]string, error) { + if !rm.hasRole(name) { + return nil, errors.New("error: name does not exist") + } + + names := []string{} + rm.allRoles.Range(func(_, value interface{}) bool { + role := value.(*Role) + if role.hasDirectRole(name) { + names = append(names, role.name) + } + return true + }) + return names, nil +} + +// PrintRoles prints all the roles to log. 
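A direct sketch of the role manager API in this file, including the domain prefixing described above:

```go
package main

import (
	"fmt"

	defaultrolemanager "github.com/casbin/casbin/rbac/default-role-manager"
)

func main() {
	// 10 bounds how many inheritance hops HasLink will follow.
	rm := defaultrolemanager.NewRoleManager(10)

	rm.AddLink("alice", "admin")           // alice inherits admin
	rm.AddLink("bob", "editor", "tenant1") // stored as tenant1::bob -> tenant1::editor

	ok, _ := rm.HasLink("alice", "admin")
	roles, _ := rm.GetRoles("bob", "tenant1") // ["editor"], prefix stripped on return
	fmt.Println(ok, roles)
}
```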
+func (rm *RoleManager) PrintRoles() error { + line := "" + rm.allRoles.Range(func(_, value interface{}) bool { + if text := value.(*Role).toString(); text != "" { + if line == "" { + line = text + } else { + line += ", " + text + } + } + return true + }) + util.LogPrint(line) + return nil +} + +// Role represents the data structure for a role in RBAC. +type Role struct { + name string + roles []*Role +} + +func newRole(name string) *Role { + r := Role{} + r.name = name + return &r +} + +func (r *Role) addRole(role *Role) { + for _, rr := range r.roles { + if rr.name == role.name { + return + } + } + + r.roles = append(r.roles, role) +} + +func (r *Role) deleteRole(role *Role) { + for i, rr := range r.roles { + if rr.name == role.name { + r.roles = append(r.roles[:i], r.roles[i+1:]...) + return + } + } +} + +func (r *Role) hasRole(name string, hierarchyLevel int) bool { + if r.name == name { + return true + } + + if hierarchyLevel <= 0 { + return false + } + + for _, role := range r.roles { + if role.hasRole(name, hierarchyLevel-1) { + return true + } + } + return false +} + +func (r *Role) hasDirectRole(name string) bool { + for _, role := range r.roles { + if role.name == name { + return true + } + } + + return false +} + +func (r *Role) toString() string { + names := "" + if len(r.roles) == 0 { + return "" + } + for i, role := range r.roles { + if i == 0 { + names += role.name + } else { + names += ", " + role.name + } + } + + if len(r.roles) == 1 { + return r.name + " < " + names + } else { + return r.name + " < (" + names + ")" + } +} + +func (r *Role) getRoles() []string { + names := []string{} + for _, role := range r.roles { + names = append(names, role.name) + } + return names +} diff --git a/vendor/github.com/casbin/casbin/rbac/role_manager.go b/vendor/github.com/casbin/casbin/rbac/role_manager.go new file mode 100644 index 00000000..b853bc93 --- /dev/null +++ b/vendor/github.com/casbin/casbin/rbac/role_manager.go @@ -0,0 +1,38 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rbac + +// RoleManager provides interface to define the operations for managing roles. +type RoleManager interface { + // Clear clears all stored data and resets the role manager to the initial state. + Clear() error + // AddLink adds the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + AddLink(name1 string, name2 string, domain ...string) error + // DeleteLink deletes the inheritance link between two roles. role: name1 and role: name2. + // domain is a prefix to the roles (can be used for other purposes). + DeleteLink(name1 string, name2 string, domain ...string) error + // HasLink determines whether a link exists between two roles. role: name1 inherits role: name2. + // domain is a prefix to the roles (can be used for other purposes). 
+ HasLink(name1 string, name2 string, domain ...string) (bool, error) + // GetRoles gets the roles that a user inherits. + // domain is a prefix to the roles (can be used for other purposes). + GetRoles(name string, domain ...string) ([]string, error) + // GetUsers gets the users that inherits a role. + // domain is a prefix to the users (can be used for other purposes). + GetUsers(name string, domain ...string) ([]string, error) + // PrintRoles prints all the roles to log. + PrintRoles() error +} diff --git a/vendor/github.com/casbin/casbin/rbac_api.go b/vendor/github.com/casbin/casbin/rbac_api.go new file mode 100644 index 00000000..4a78b6fd --- /dev/null +++ b/vendor/github.com/casbin/casbin/rbac_api.go @@ -0,0 +1,127 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetRolesForUser gets the roles that a user has. +func (e *Enforcer) GetRolesForUser(name string) []string { + res, _ := e.model["g"]["g"].RM.GetRoles(name) + return res +} + +// GetUsersForRole gets the users that has a role. +func (e *Enforcer) GetUsersForRole(name string) []string { + res, _ := e.model["g"]["g"].RM.GetUsers(name) + return res +} + +// HasRoleForUser determines whether a user has a role. +func (e *Enforcer) HasRoleForUser(name string, role string) bool { + roles := e.GetRolesForUser(name) + + hasRole := false + for _, r := range roles { + if r == role { + hasRole = true + break + } + } + + return hasRole +} + +// AddRoleForUser adds a role for a user. +// Returns false if the user already has the role (aka not affected). +func (e *Enforcer) AddRoleForUser(user string, role string) bool { + return e.AddGroupingPolicy(user, role) +} + +// DeleteRoleForUser deletes a role for a user. +// Returns false if the user does not have the role (aka not affected). +func (e *Enforcer) DeleteRoleForUser(user string, role string) bool { + return e.RemoveGroupingPolicy(user, role) +} + +// DeleteRolesForUser deletes all roles for a user. +// Returns false if the user does not have any roles (aka not affected). +func (e *Enforcer) DeleteRolesForUser(user string) bool { + return e.RemoveFilteredGroupingPolicy(0, user) +} + +// DeleteUser deletes a user. +// Returns false if the user does not exist (aka not affected). +func (e *Enforcer) DeleteUser(user string) bool { + return e.RemoveFilteredGroupingPolicy(0, user) +} + +// DeleteRole deletes a role. +func (e *Enforcer) DeleteRole(role string) { + e.RemoveFilteredGroupingPolicy(1, role) + e.RemoveFilteredPolicy(0, role) +} + +// DeletePermission deletes a permission. +// Returns false if the permission does not exist (aka not affected). +func (e *Enforcer) DeletePermission(permission ...string) bool { + return e.RemoveFilteredPolicy(1, permission...) +} + +// AddPermissionForUser adds a permission for a user or role. +// Returns false if the user or role already has the permission (aka not affected). 
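The RBAC helpers above express roles as g rules and role permissions as p rules; a sketch (enforcer construction as in the earlier examples, with an RBAC model assumed):

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewEnforcer("rbac_model.conf", "rbac_policy.csv")

	e.AddRoleForUser("alice", "admin")      // writes: g, alice, admin
	fmt.Println(e.GetRolesForUser("alice")) // ["admin"]
	fmt.Println(e.HasRoleForUser("alice", "admin"))

	// Removes the role both as a grantee (g field 1) and as a
	// policy subject (p field 0).
	e.DeleteRole("admin")
}
```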
+func (e *Enforcer) AddPermissionForUser(user string, permission ...string) bool { + params := make([]interface{}, 0, len(permission)+1) + + params = append(params, user) + for _, perm := range permission { + params = append(params, perm) + } + + return e.AddPolicy(params...) +} + +// DeletePermissionForUser deletes a permission for a user or role. +// Returns false if the user or role does not have the permission (aka not affected). +func (e *Enforcer) DeletePermissionForUser(user string, permission ...string) bool { + params := make([]interface{}, 0, len(permission)+1) + + params = append(params, user) + for _, perm := range permission { + params = append(params, perm) + } + + return e.RemovePolicy(params...) +} + +// DeletePermissionsForUser deletes permissions for a user or role. +// Returns false if the user or role does not have any permissions (aka not affected). +func (e *Enforcer) DeletePermissionsForUser(user string) bool { + return e.RemoveFilteredPolicy(0, user) +} + +// GetPermissionsForUser gets permissions for a user or role. +func (e *Enforcer) GetPermissionsForUser(user string) [][]string { + return e.GetFilteredPolicy(0, user) +} + +// HasPermissionForUser determines whether a user has a permission. +func (e *Enforcer) HasPermissionForUser(user string, permission ...string) bool { + params := make([]interface{}, 0, len(permission)+1) + + params = append(params, user) + for _, perm := range permission { + params = append(params, perm) + } + + return e.HasPolicy(params...) +} diff --git a/vendor/github.com/casbin/casbin/rbac_api_synced.go b/vendor/github.com/casbin/casbin/rbac_api_synced.go new file mode 100644 index 00000000..119ce9f8 --- /dev/null +++ b/vendor/github.com/casbin/casbin/rbac_api_synced.go @@ -0,0 +1,121 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetRolesForUser gets the roles that a user has. +func (e *SyncedEnforcer) GetRolesForUser(name string) []string { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.GetRolesForUser(name) +} + +// GetUsersForRole gets the users that has a role. +func (e *SyncedEnforcer) GetUsersForRole(name string) []string { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.GetUsersForRole(name) +} + +// HasRoleForUser determines whether a user has a role. +func (e *SyncedEnforcer) HasRoleForUser(name string, role string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.HasRoleForUser(name, role) +} + +// AddRoleForUser adds a role for a user. +// Returns false if the user already has the role (aka not affected). +func (e *SyncedEnforcer) AddRoleForUser(user string, role string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddRoleForUser(user, role) +} + +// DeleteRoleForUser deletes a role for a user. +// Returns false if the user does not have the role (aka not affected). 
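Likewise for the permission helpers above, which prepend the user or role to the p rule; a sketch continuing the same setup:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewEnforcer("rbac_model.conf", "rbac_policy.csv")

	e.AddPermissionForUser("admin", "data1", "read") // writes: p, admin, data1, read
	fmt.Println(e.GetPermissionsForUser("admin"))    // [["admin","data1","read"]]
	fmt.Println(e.HasPermissionForUser("admin", "data1", "read"))

	e.DeletePermissionsForUser("admin") // drops every p rule whose sub is "admin"
}
```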
+func (e *SyncedEnforcer) DeleteRoleForUser(user string, role string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRoleForUser(user, role) +} + +// DeleteRolesForUser deletes all roles for a user. +// Returns false if the user does not have any roles (aka not affected). +func (e *SyncedEnforcer) DeleteRolesForUser(user string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteRolesForUser(user) +} + +// DeleteUser deletes a user. +// Returns false if the user does not exist (aka not affected). +func (e *SyncedEnforcer) DeleteUser(user string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeleteUser(user) +} + +// DeleteRole deletes a role. +func (e *SyncedEnforcer) DeleteRole(role string) { + e.m.Lock() + defer e.m.Unlock() + e.Enforcer.DeleteRole(role) +} + +// DeletePermission deletes a permission. +// Returns false if the permission does not exist (aka not affected). +func (e *SyncedEnforcer) DeletePermission(permission ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermission(permission...) +} + +// AddPermissionForUser adds a permission for a user or role. +// Returns false if the user or role already has the permission (aka not affected). +func (e *SyncedEnforcer) AddPermissionForUser(user string, permission ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.AddPermissionForUser(user, permission...) +} + +// DeletePermissionForUser deletes a permission for a user or role. +// Returns false if the user or role does not have the permission (aka not affected). +func (e *SyncedEnforcer) DeletePermissionForUser(user string, permission ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermissionForUser(user, permission...) +} + +// DeletePermissionsForUser deletes permissions for a user or role. +// Returns false if the user or role does not have any permissions (aka not affected). +func (e *SyncedEnforcer) DeletePermissionsForUser(user string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.DeletePermissionsForUser(user) +} + +// GetPermissionsForUser gets permissions for a user or role. +func (e *SyncedEnforcer) GetPermissionsForUser(user string) [][]string { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.GetPermissionsForUser(user) +} + +// HasPermissionForUser determines whether a user has a permission. +func (e *SyncedEnforcer) HasPermissionForUser(user string, permission ...string) bool { + e.m.Lock() + defer e.m.Unlock() + return e.Enforcer.HasPermissionForUser(user, permission...) +} diff --git a/vendor/github.com/casbin/casbin/rbac_api_with_domains.go b/vendor/github.com/casbin/casbin/rbac_api_with_domains.go new file mode 100644 index 00000000..d6af0d6d --- /dev/null +++ b/vendor/github.com/casbin/casbin/rbac_api_with_domains.go @@ -0,0 +1,38 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
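rbac_api_synced.go, completed above, simply takes the enforcer's mutex around each embedded `Enforcer` call, which makes the whole RBAC API safe to drive from concurrent goroutines. A hedged sketch (hypothetical file paths again, and `NewSyncedEnforcer` assumed from the public casbin API):

```go
package main

import (
	"sync"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewSyncedEnforcer("conf/rbac_model.conf", "conf/rbac_policy.csv")

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Safe: each call locks e.m before touching shared state.
			e.AddRoleForUser("alice", "admin")
			_ = e.GetRolesForUser("alice")
		}()
	}
	wg.Wait()
}
```

Since the getters take the same exclusive lock as the writers in this version, concurrent reads serialize as well; that is the price of the simple wrapper.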
+ +package casbin + +// GetRolesForUserInDomain gets the roles that a user has inside a domain. +func (e *Enforcer) GetRolesForUserInDomain(name string, domain string) []string { + res, _ := e.model["g"]["g"].RM.GetRoles(name, domain) + return res +} + +// GetPermissionsForUserInDomain gets permissions for a user or role inside a domain. +func (e *Enforcer) GetPermissionsForUserInDomain(user string, domain string) [][]string { + return e.GetFilteredPolicy(0, user, domain) +} + +// AddRoleForUserInDomain adds a role for a user inside a domain. +// Returns false if the user already has the role (aka not affected). +func (e *Enforcer) AddRoleForUserInDomain(user string, role string, domain string) bool { + return e.AddGroupingPolicy(user, role, domain) +} + +// DeleteRoleForUserInDomain deletes a role for a user inside a domain. +// Returns false if the user does not have the role (aka not affected). +func (e *Enforcer) DeleteRoleForUserInDomain(user string, role string, domain string) bool { + return e.RemoveGroupingPolicy(user, role, domain) +} diff --git a/vendor/github.com/casbin/casbin/util/builtin_operators.go b/vendor/github.com/casbin/casbin/util/builtin_operators.go new file mode 100644 index 00000000..98b9d98e --- /dev/null +++ b/vendor/github.com/casbin/casbin/util/builtin_operators.go @@ -0,0 +1,160 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "net" + "regexp" + "strings" + + "github.com/casbin/casbin/rbac" +) + +// KeyMatch determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// For example, "/foo/bar" matches "/foo/*" +func KeyMatch(key1 string, key2 string) bool { + i := strings.Index(key2, "*") + if i == -1 { + return key1 == key2 + } + + if len(key1) > i { + return key1[:i] == key2[:i] + } + return key1 == key2[:i] +} + +// KeyMatchFunc is the wrapper for KeyMatch. +func KeyMatchFunc(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + return (bool)(KeyMatch(name1, name2)), nil +} + +// KeyMatch2 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. +// For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource" +func KeyMatch2(key1 string, key2 string) bool { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`(.*):[^/]+(.*)`) + for { + if !strings.Contains(key2, "/:") { + break + } + + key2 = "^" + re.ReplaceAllString(key2, "$1[^/]+$2") + "$" + } + + return RegexMatch(key1, key2) +} + +// KeyMatch2Func is the wrapper for KeyMatch2. +func KeyMatch2Func(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + return (bool)(KeyMatch2(name1, name2)), nil +} + +// KeyMatch3 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. 
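Stepping back to rbac_api_with_domains.go above: the domain-aware calls thread an extra tenant argument through to the grouping policy. A sketch under the assumption of a model whose role definition takes three arguments (`g = _, _, _`), with hypothetical file paths:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	// Hypothetical files; the model's role definition must be `g = _, _, _`.
	e := casbin.NewEnforcer("conf/rbac_with_domains_model.conf", "conf/rbac_with_domains_policy.csv")

	// alice is an admin, but only inside domain1.
	e.AddRoleForUserInDomain("alice", "admin", "domain1")

	fmt.Println(e.GetRolesForUserInDomain("alice", "domain1")) // e.g. [admin]
	fmt.Println(e.GetRolesForUserInDomain("alice", "domain2")) // e.g. []
}
```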
+// For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}" +func KeyMatch3(key1 string, key2 string) bool { + key2 = strings.Replace(key2, "/*", "/.*", -1) + + re := regexp.MustCompile(`(.*)\{[^/]+\}(.*)`) + for { + if !strings.Contains(key2, "/{") { + break + } + + key2 = re.ReplaceAllString(key2, "$1[^/]+$2") + } + + return RegexMatch(key1, key2) +} + +// KeyMatch3Func is the wrapper for KeyMatch3. +func KeyMatch3Func(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + return (bool)(KeyMatch3(name1, name2)), nil +} + +// RegexMatch determines whether key1 matches the pattern of key2 in regular expression. +func RegexMatch(key1 string, key2 string) bool { + res, err := regexp.MatchString(key2, key1) + if err != nil { + panic(err) + } + return res +} + +// RegexMatchFunc is the wrapper for RegexMatch. +func RegexMatchFunc(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + return (bool)(RegexMatch(name1, name2)), nil +} + +// IPMatch determines whether IP address ip1 matches the pattern of IP address ip2, ip2 can be an IP address or a CIDR pattern. +// For example, "192.168.2.123" matches "192.168.2.0/24" +func IPMatch(ip1 string, ip2 string) bool { + objIP1 := net.ParseIP(ip1) + if objIP1 == nil { + panic("invalid argument: ip1 in IPMatch() function is not an IP address.") + } + + _, cidr, err := net.ParseCIDR(ip2) + if err != nil { + objIP2 := net.ParseIP(ip2) + if objIP2 == nil { + panic("invalid argument: ip2 in IPMatch() function is neither an IP address nor a CIDR.") + } + + return objIP1.Equal(objIP2) + } + + return cidr.Contains(objIP1) +} + +// IPMatchFunc is the wrapper for IPMatch. +func IPMatchFunc(args ...interface{}) (interface{}, error) { + ip1 := args[0].(string) + ip2 := args[1].(string) + + return (bool)(IPMatch(ip1, ip2)), nil +} + +// GenerateGFunction is the factory method of the g(_, _) function. +func GenerateGFunction(rm rbac.RoleManager) func(args ...interface{}) (interface{}, error) { + return func(args ...interface{}) (interface{}, error) { + name1 := args[0].(string) + name2 := args[1].(string) + + if rm == nil { + return name1 == name2, nil + } else if len(args) == 2 { + res, _ := rm.HasLink(name1, name2) + return res, nil + } else { + domain := args[2].(string) + res, _ := rm.HasLink(name1, name2, domain) + return res, nil + } + } +} diff --git a/vendor/github.com/casbin/casbin/util/log_util.go b/vendor/github.com/casbin/casbin/util/log_util.go new file mode 100644 index 00000000..c9e692be --- /dev/null +++ b/vendor/github.com/casbin/casbin/util/log_util.go @@ -0,0 +1,34 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import "log" + +// EnableLog controls whether to print log to console. +var EnableLog = true + +// LogPrint prints the log. +func LogPrint(v ...interface{}) { + if EnableLog { + log.Print(v...) 
+ } +} + +// LogPrintf prints the log with the format. +func LogPrintf(format string, v ...interface{}) { + if EnableLog { + log.Printf(format, v...) + } +} diff --git a/vendor/github.com/casbin/casbin/util/util.go b/vendor/github.com/casbin/casbin/util/util.go new file mode 100644 index 00000000..4ce033d5 --- /dev/null +++ b/vendor/github.com/casbin/casbin/util/util.go @@ -0,0 +1,105 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "sort" + "strings" +) + +// EscapeAssertion escapes the dots in the assertion, because the expression evaluation doesn't support such variable names. +func EscapeAssertion(s string) string { + s = strings.Replace(s, "r.", "r_", -1) + s = strings.Replace(s, "p.", "p_", -1) + return s +} + +// RemoveComments removes the comments starting with # in the text. +func RemoveComments(s string) string { + pos := strings.Index(s, "#") + if pos == -1 { + return s + } + return strings.TrimSpace(s[0:pos]) +} + +// ArrayEquals determines whether two string arrays are identical. +func ArrayEquals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +// Array2DEquals determines whether two 2-dimensional string arrays are identical. +func Array2DEquals(a [][]string, b [][]string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if !ArrayEquals(v, b[i]) { + return false + } + } + return true +} + +// ArrayRemoveDuplicates removes any duplicated elements in a string array. +func ArrayRemoveDuplicates(s *[]string) { + found := make(map[string]bool) + j := 0 + for i, x := range *s { + if !found[x] { + found[x] = true + (*s)[j] = (*s)[i] + j++ + } + } + *s = (*s)[:j] +} + +// ArrayToString gets a printable string for a string array. +func ArrayToString(s []string) string { + return strings.Join(s, ", ") +} + +// ParamsToString gets a printable string for variable number of parameters. +func ParamsToString(s ...string) string { + return strings.Join(s, ", ") +} + +// SetEquals determines whether two string sets are identical. 
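builtin_operators.go and util.go together form the `util` package consumed by matcher expressions. The examples in the comments above translate directly into calls; a small, runnable sketch with the expected outputs noted inline:

```go
package main

import (
	"fmt"

	"github.com/casbin/casbin/util"
)

func main() {
	// RESTful-style matchers, as documented above.
	fmt.Println(util.KeyMatch("/foo/bar", "/foo/*"))              // true
	fmt.Println(util.KeyMatch2("/resource1", "/:resource"))       // true
	fmt.Println(util.KeyMatch3("/foo/bar", "/foo/{id}"))          // true
	fmt.Println(util.RegexMatch("key123", "^key[0-9]+$"))         // true
	fmt.Println(util.IPMatch("192.168.2.123", "192.168.2.0/24")) // true

	// Text helpers used while parsing models and policies.
	fmt.Println(util.EscapeAssertion("r.sub == p.sub"))    // r_sub == p_sub
	fmt.Println(util.RemoveComments("g = _, _ # comment")) // g = _, _

	s := []string{"a", "b", "a"}
	util.ArrayRemoveDuplicates(&s)
	fmt.Println(s) // [a b]
}
```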
+func SetEquals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + sort.Strings(a) + sort.Strings(b) + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/vendor/github.com/cloudflare/golz4/.gitignore b/vendor/github.com/cloudflare/golz4/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/vendor/github.com/cloudflare/golz4/LICENSE b/vendor/github.com/cloudflare/golz4/LICENSE new file mode 100644 index 00000000..1579e81a --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013 CloudFlare, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +* Neither the name of the CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/cloudflare/golz4/Makefile b/vendor/github.com/cloudflare/golz4/Makefile new file mode 100644 index 00000000..2296d80e --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/Makefile @@ -0,0 +1,14 @@ +GCFLAGS := +LDFLAGS := + +.PHONY: install +install: + @go install -v . + +.PHONY: test +test: + @go test -gcflags='$(GCFLAGS)' -ldflags='$(LDFLAGS)' . + +.PHONY: bench +bench: + @go test -gcflags='$(GCFLAGS)' -ldflags='$(LDFLAGS)' -bench . diff --git a/vendor/github.com/cloudflare/golz4/README.md b/vendor/github.com/cloudflare/golz4/README.md new file mode 100644 index 00000000..e1bdb26e --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/README.md @@ -0,0 +1,4 @@ +golz4 +===== + +Golang interface to LZ4 compression diff --git a/vendor/github.com/cloudflare/golz4/doc.go b/vendor/github.com/cloudflare/golz4/doc.go new file mode 100644 index 00000000..4876be87 --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/doc.go @@ -0,0 +1,4 @@ +// Package lz4 implements compression using lz4.c and lz4hc.c +// +// Copyright (c) 2013 CloudFlare, Inc. 
+package lz4 diff --git a/vendor/github.com/cloudflare/golz4/lz4.go b/vendor/github.com/cloudflare/golz4/lz4.go new file mode 100644 index 00000000..f9abcb2d --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/lz4.go @@ -0,0 +1,55 @@ +package lz4 + +// #cgo CFLAGS: -O3 +// #include "src/lz4.h" +// #include "src/lz4.c" +import "C" + +import ( + "errors" + "fmt" + "unsafe" +) + +// p gets a char pointer to the first byte of a []byte slice +func p(in []byte) *C.char { + if len(in) == 0 { + return (*C.char)(unsafe.Pointer(nil)) + } + return (*C.char)(unsafe.Pointer(&in[0])) +} + +// clen gets the length of a []byte slice as a char * +func clen(s []byte) C.int { + return C.int(len(s)) +} + +// Uncompress with a known output size. len(out) should be equal to +// the length of the uncompressed out. +func Uncompress(in, out []byte) (error) { + if int(C.LZ4_decompress_safe(p(in), p(out), clen(in), clen(out))) < 0 { + return errors.New("Malformed compression stream") + } + + return nil +} + +// CompressBound calculates the size of the output buffer needed by +// Compress. This is based on the following macro: +// +// #define LZ4_COMPRESSBOUND(isize) +// ((unsigned int)(isize) > (unsigned int)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16) +func CompressBound(in []byte) int { + return len(in) + ((len(in) / 255) + 16) +} + +// Compress compresses in and puts the content in out. len(out) +// should have enough space for the compressed data (use CompressBound +// to calculate). Returns the number of bytes in the out slice. +func Compress(in, out []byte) (outSize int, err error) { + outSize = int(C.LZ4_compress_limitedOutput(p(in), p(out), clen(in), clen(out))) + if outSize == 0 { + err = fmt.Errorf("insufficient space for compression") + } + return +} diff --git a/vendor/github.com/cloudflare/golz4/lz4_hc.go b/vendor/github.com/cloudflare/golz4/lz4_hc.go new file mode 100644 index 00000000..9779352c --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/lz4_hc.go @@ -0,0 +1,38 @@ +package lz4 + +// #cgo CFLAGS: -O3 +// #include "src/lz4hc.h" +// #include "src/lz4hc.c" +import "C" + +import ( + "fmt" +) + +// CompressHC compresses in and puts the content in out. len(out) +// should have enough space for the compressed data (use CompressBound +// to calculate). Returns the number of bytes in the out slice. Determines +// the compression level automatically. +func CompressHC(in, out []byte) (int, error) { + // 0 automatically sets the compression level. + return CompressHCLevel(in, out, 0) +} + +// CompressHCLevel compresses in at the given compression level and puts the +// content in out. len(out) should have enough space for the compressed data +// (use CompressBound to calculate). Returns the number of bytes in the out +// slice. To automatically choose the compression level, use 0. Otherwise, use +// any value in the inclusive range 1 (worst) through 16 (best). Most +// applications will prefer CompressHC. +func CompressHCLevel(in, out []byte, level int) (outSize int, err error) { + // LZ4HC does not handle empty buffers. Pass through to Compress. 
+ if len(in) == 0 || len(out) == 0 { + return Compress(in, out) + } + + outSize = int(C.LZ4_compressHC2_limitedOutput(p(in), p(out), clen(in), clen(out), C.int(level))) + if outSize == 0 { + err = fmt.Errorf("insufficient space for compression") + } + return +} diff --git a/vendor/github.com/cloudflare/golz4/sample.txt b/vendor/github.com/cloudflare/golz4/sample.txt new file mode 100644 index 00000000..3bb27fb7 --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/sample.txt @@ -0,0 +1,143 @@ +CANTO I + + +IN the midway of this our mortal life, +I found me in a gloomy wood, astray +Gone from the path direct: and e'en to tell +It were no easy task, how savage wild +That forest, how robust and rough its growth, +Which to remember only, my dismay +Renews, in bitterness not far from death. +Yet to discourse of what there good befell, +All else will I relate discover'd there. +How first I enter'd it I scarce can say, +Such sleepy dullness in that instant weigh'd +My senses down, when the true path I left, +But when a mountain's foot I reach'd, where clos'd +The valley, that had pierc'd my heart with dread, +I look'd aloft, and saw his shoulders broad +Already vested with that planet's beam, +Who leads all wanderers safe through every way. + +Then was a little respite to the fear, +That in my heart's recesses deep had lain, +All of that night, so pitifully pass'd: +And as a man, with difficult short breath, +Forespent with toiling, 'scap'd from sea to shore, +Turns to the perilous wide waste, and stands +At gaze; e'en so my spirit, that yet fail'd +Struggling with terror, turn'd to view the straits, +That none hath pass'd and liv'd. My weary frame +After short pause recomforted, again +I journey'd on over that lonely steep, + +The hinder foot still firmer. Scarce the ascent +Began, when, lo! a panther, nimble, light, +And cover'd with a speckled skin, appear'd, +Nor, when it saw me, vanish'd, rather strove +To check my onward going; that ofttimes +With purpose to retrace my steps I turn'd. + +The hour was morning's prime, and on his way +Aloft the sun ascended with those stars, +That with him rose, when Love divine first mov'd +Those its fair works: so that with joyous hope +All things conspir'd to fill me, the gay skin +Of that swift animal, the matin dawn +And the sweet season. Soon that joy was chas'd, +And by new dread succeeded, when in view +A lion came, 'gainst me, as it appear'd, + +With his head held aloft and hunger-mad, +That e'en the air was fear-struck. A she-wolf +Was at his heels, who in her leanness seem'd +Full of all wants, and many a land hath made +Disconsolate ere now. She with such fear +O'erwhelmed me, at the sight of her appall'd, +That of the height all hope I lost. As one, +Who with his gain elated, sees the time +When all unwares is gone, he inwardly +Mourns with heart-griping anguish; such was I, +Haunted by that fell beast, never at peace, +Who coming o'er against me, by degrees +Impell'd me where the sun in silence rests. + +While to the lower space with backward step +I fell, my ken discern'd the form one of one, +Whose voice seem'd faint through long disuse of speech. +When him in that great desert I espied, +"Have mercy on me!" cried I out aloud, +"Spirit! or living man! what e'er thou be!" + +He answer'd: "Now not man, man once I was, +And born of Lombard parents, Mantuana both +By country, when the power of Julius yet +Was scarcely firm. At Rome my life was past +Beneath the mild Augustus, in the time +Of fabled deities and false. 
A bard +Was I, and made Anchises' upright son +The subject of my song, who came from Troy, +When the flames prey'd on Ilium's haughty towers. +But thou, say wherefore to such perils past +Return'st thou? wherefore not this pleasant mount +Ascendest, cause and source of all delight?" +"And art thou then that Virgil, that well-spring, +From which such copious floods of eloquence +Have issued?" I with front abash'd replied. +"Glory and light of all the tuneful train! +May it avail me that I long with zeal +Have sought thy volume, and with love immense +Have conn'd it o'er. My master thou and guide! +Thou he from whom alone I have deriv'd +That style, which for its beauty into fame +Exalts me. See the beast, from whom I fled. +O save me from her, thou illustrious sage!" + +"For every vein and pulse throughout my frame +She hath made tremble." He, soon as he saw +That I was weeping, answer'd, "Thou must needs +Another way pursue, if thou wouldst 'scape +From out that savage wilderness. This beast, +At whom thou criest, her way will suffer none +To pass, and no less hindrance makes than death: +So bad and so accursed in her kind, +That never sated is her ravenous will, +Still after food more craving than before. +To many an animal in wedlock vile +She fastens, and shall yet to many more, +Until that greyhound come, who shall destroy +Her with sharp pain. He will not life support +By earth nor its base metals, but by love, +Wisdom, and virtue, and his land shall be +The land 'twixt either Feltro. In his might +Shall safety to Italia's plains arise, +For whose fair realm, Camilla, virgin pure, +Nisus, Euryalus, and Turnus fell. +He with incessant chase through every town +Shall worry, until he to hell at length +Restore her, thence by envy first let loose. +I for thy profit pond'ring now devise, +That thou mayst follow me, and I thy guide +Will lead thee hence through an eternal space, +Where thou shalt hear despairing shrieks, and see +Spirits of old tormented, who invoke +A second death; and those next view, who dwell +Content in fire, for that they hope to come, +Whene'er the time may be, among the blest, +Into whose regions if thou then desire +T' ascend, a spirit worthier then I +Must lead thee, in whose charge, when I depart, +Thou shalt be left: for that Almighty King, +Who reigns above, a rebel to his law, +Adjudges me, and therefore hath decreed, +That to his city none through me should come. +He in all parts hath sway; there rules, there holds +His citadel and throne. O happy those, +Whom there he chooses!" I to him in few: +"Bard! by that God, whom thou didst not adore, +I do beseech thee (that this ill and worse +I may escape) to lead me, where thou saidst, +That I Saint Peter's gate may view, and those +Who as thou tell'st, are in such dismal plight." + +Onward he mov'd, I close his steps pursu'd. diff --git a/vendor/github.com/cloudflare/golz4/src/lz4.c b/vendor/github.com/cloudflare/golz4/src/lz4.c new file mode 100644 index 00000000..ed928ced --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/src/lz4.c @@ -0,0 +1,1367 @@ +/* + LZ4 - Fast LZ compression algorithm + Copyright (C) 2011-2015, Yann Collet. + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 source repository : http://code.google.com/p/lz4 + - LZ4 source mirror : https://github.com/Cyan4973/lz4 + - LZ4 public forum : https://groups.google.com/forum/#!forum/lz4c +*/ + + +/************************************** + Tuning parameters +**************************************/ +/* + * HEAPMODE : + * Select how default compression functions will allocate memory for their hash table, + * in memory stack (0:default, fastest), or in memory heap (1:requires malloc()). + */ +#define HEAPMODE 0 + +/* + * CPU_HAS_EFFICIENT_UNALIGNED_MEMORY_ACCESS : + * By default, the source code expects the compiler to correctly optimize + * 4-bytes and 8-bytes read on architectures able to handle it efficiently. + * This is not always the case. In some circumstances (ARM notably), + * the compiler will issue cautious code even when target is able to correctly handle unaligned memory accesses. + * + * You can force the compiler to use unaligned memory access by uncommenting the line below. + * One of the below scenarios will happen : + * 1 - Your target CPU correctly handle unaligned access, and was not well optimized by compiler (good case). + * You will witness large performance improvements (+50% and up). + * Keep the line uncommented and send a word to upstream (https://groups.google.com/forum/#!forum/lz4c) + * The goal is to automatically detect such situations by adding your target CPU within an exception list. + * 2 - Your target CPU correctly handle unaligned access, and was already already optimized by compiler + * No change will be experienced. + * 3 - Your target CPU inefficiently handle unaligned access. + * You will experience a performance loss. Comment back the line. + * 4 - Your target CPU does not handle unaligned access. + * Program will crash. + * If uncommenting results in better performance (case 1) + * please report your configuration to upstream (https://groups.google.com/forum/#!forum/lz4c) + * An automatic detection macro will be added to match your case within future versions of the library. 
+ */
+/* #define CPU_HAS_EFFICIENT_UNALIGNED_MEMORY_ACCESS 1 */
+
+
+/**************************************
+ CPU Feature Detection
+**************************************/
+/*
+ * Automated efficient unaligned memory access detection
+ * Based on known hardware architectures
+ * This list will be updated thanks to feedbacks
+ */
+#if defined(CPU_HAS_EFFICIENT_UNALIGNED_MEMORY_ACCESS) \
+ || defined(__ARM_FEATURE_UNALIGNED) \
+ || defined(__i386__) || defined(__x86_64__) \
+ || defined(_M_IX86) || defined(_M_X64) \
+ || defined(__ARM_ARCH_7__) || defined(__ARM_ARCH_8__) \
+ || (defined(_M_ARM) && (_M_ARM >= 7))
+# define LZ4_UNALIGNED_ACCESS 1
+#else
+# define LZ4_UNALIGNED_ACCESS 0
+#endif
+
+/*
+ * LZ4_FORCE_SW_BITCOUNT
+ * Define this parameter if your target system or compiler does not support hardware bit count
+ */
+#if defined(_MSC_VER) && defined(_WIN32_WCE) /* Visual Studio for Windows CE does not support Hardware bit count */
+# define LZ4_FORCE_SW_BITCOUNT
+#endif
+
+
+/**************************************
+ Compiler Options
+**************************************/
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */
+/* "restrict" is a known keyword */
+#else
+# define restrict /* Disable restrict */
+#endif
+
+#ifdef _MSC_VER /* Visual Studio */
+# define FORCE_INLINE static __forceinline
+# include <intrin.h>
+# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
+# pragma warning(disable : 4293) /* disable: C4293: too large shift (32-bits) */
+#else
+# if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */
+# ifdef __GNUC__
+# define FORCE_INLINE static inline __attribute__((always_inline))
+# else
+# define FORCE_INLINE static inline
+# endif
+# else
+# define FORCE_INLINE static
+# endif /* __STDC_VERSION__ */
+#endif /* _MSC_VER */
+
+#define GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
+
+#if (GCC_VERSION >= 302) || (__INTEL_COMPILER >= 800) || defined(__clang__)
+# define expect(expr,value) (__builtin_expect ((expr),(value)) )
+#else
+# define expect(expr,value) (expr)
+#endif
+
+#define likely(expr) expect((expr) != 0, 1)
+#define unlikely(expr) expect((expr) != 0, 0)
+
+
+/**************************************
+ Memory routines
+**************************************/
+#include <stdlib.h> /* malloc, calloc, free */
+#define ALLOCATOR(n,s) calloc(n,s)
+#define FREEMEM free
+#include <string.h> /* memset, memcpy */
+#define MEM_INIT memset
+
+
+/**************************************
+ Includes
+**************************************/
+#include "lz4.h"
+
+
+/**************************************
+ Basic Types
+**************************************/
+#if defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */
+# include <stdint.h>
+ typedef uint8_t BYTE;
+ typedef uint16_t U16;
+ typedef uint32_t U32;
+ typedef int32_t S32;
+ typedef uint64_t U64;
+#else
+ typedef unsigned char BYTE;
+ typedef unsigned short U16;
+ typedef unsigned int U32;
+ typedef signed int S32;
+ typedef unsigned long long U64;
+#endif
+
+
+/**************************************
+ Reading and writing into memory
+**************************************/
+#define STEPSIZE sizeof(size_t)
+
+static unsigned LZ4_64bits(void) { return sizeof(void*)==8; }
+
+static unsigned LZ4_isLittleEndian(void)
+{
+ const union { U32 i; BYTE c[4]; } one = { 1 }; /* don't use static : performance detrimental */
+ return one.c[0];
+}
+
+
+static U16 LZ4_readLE16(const void* memPtr)
+{
+ if ((LZ4_UNALIGNED_ACCESS) && (LZ4_isLittleEndian()))
+ return *(U16*)memPtr;
+ else
+ { + const BYTE* p = memPtr; + return (U16)((U16)p[0] + (p[1]<<8)); + } +} + +static void LZ4_writeLE16(void* memPtr, U16 value) +{ + if ((LZ4_UNALIGNED_ACCESS) && (LZ4_isLittleEndian())) + { + *(U16*)memPtr = value; + return; + } + else + { + BYTE* p = memPtr; + p[0] = (BYTE) value; + p[1] = (BYTE)(value>>8); + } +} + + +static U16 LZ4_read16(const void* memPtr) +{ + if (LZ4_UNALIGNED_ACCESS) + return *(U16*)memPtr; + else + { + U16 val16; + memcpy(&val16, memPtr, 2); + return val16; + } +} + +static U32 LZ4_read32(const void* memPtr) +{ + if (LZ4_UNALIGNED_ACCESS) + return *(U32*)memPtr; + else + { + U32 val32; + memcpy(&val32, memPtr, 4); + return val32; + } +} + +static U64 LZ4_read64(const void* memPtr) +{ + if (LZ4_UNALIGNED_ACCESS) + return *(U64*)memPtr; + else + { + U64 val64; + memcpy(&val64, memPtr, 8); + return val64; + } +} + +static size_t LZ4_read_ARCH(const void* p) +{ + if (LZ4_64bits()) + return (size_t)LZ4_read64(p); + else + return (size_t)LZ4_read32(p); +} + + +static void LZ4_copy4(void* dstPtr, const void* srcPtr) +{ + if (LZ4_UNALIGNED_ACCESS) + { + *(U32*)dstPtr = *(U32*)srcPtr; + return; + } + memcpy(dstPtr, srcPtr, 4); +} + +static void LZ4_copy8(void* dstPtr, const void* srcPtr) +{ +#if GCC_VERSION!=409 /* disabled on GCC 4.9, as it generates invalid opcode (crash) */ + if (LZ4_UNALIGNED_ACCESS) + { + if (LZ4_64bits()) + *(U64*)dstPtr = *(U64*)srcPtr; + else + ((U32*)dstPtr)[0] = ((U32*)srcPtr)[0], + ((U32*)dstPtr)[1] = ((U32*)srcPtr)[1]; + return; + } +#endif + memcpy(dstPtr, srcPtr, 8); +} + +/* customized version of memcpy, which may overwrite up to 7 bytes beyond dstEnd */ +static void LZ4_wildCopy(void* dstPtr, const void* srcPtr, void* dstEnd) +{ + BYTE* d = dstPtr; + const BYTE* s = srcPtr; + BYTE* e = dstEnd; + do { LZ4_copy8(d,s); d+=8; s+=8; } while (d>3); +# elif defined(__GNUC__) && (GCC_VERSION >= 304) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (__builtin_ctzll((U64)val) >> 3); +# else + static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, 0, 3, 1, 3, 1, 4, 2, 7, 0, 2, 3, 6, 1, 5, 3, 5, 1, 3, 4, 4, 2, 5, 6, 7, 7, 0, 1, 2, 3, 3, 4, 6, 2, 6, 5, 5, 3, 4, 5, 6, 7, 1, 2, 4, 6, 4, 4, 5, 7, 2, 6, 5, 7, 6, 7, 7 }; + return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; +# endif + } + else /* 32 bits */ + { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r; + _BitScanForward( &r, (U32)val ); + return (int)(r>>3); +# elif defined(__GNUC__) && (GCC_VERSION >= 304) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (__builtin_ctz((U32)val) >> 3); +# else + static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, 3, 2, 2, 1, 3, 2, 0, 1, 3, 3, 1, 2, 2, 2, 2, 0, 3, 1, 2, 0, 1, 0, 1, 1 }; + return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; +# endif + } + } + else /* Big Endian CPU */ + { + if (LZ4_64bits()) + { +# if defined(_MSC_VER) && defined(_WIN64) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse64( &r, val ); + return (unsigned)(r>>3); +# elif defined(__GNUC__) && (GCC_VERSION >= 304) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (__builtin_clzll(val) >> 3); +# else + unsigned r; + if (!(val>>32)) { r=4; } else { r=0; val>>=32; } + if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } + r += (!val); + return r; +# endif + } + else /* 32 bits */ + { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse( &r, (unsigned long)val ); + return (unsigned)(r>>3); +# elif defined(__GNUC__) && 
(GCC_VERSION >= 304) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (__builtin_clz(val) >> 3); +# else + unsigned r; + if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } + r += (!val); + return r; +# endif + } + } +} + +static unsigned LZ4_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* pInLimit) +{ + const BYTE* const pStart = pIn; + + while (likely(pIn compression run slower on incompressible data */ + + +/************************************** + Local Utils +**************************************/ +int LZ4_versionNumber (void) { return LZ4_VERSION_NUMBER; } +int LZ4_compressBound(int isize) { return LZ4_COMPRESSBOUND(isize); } + + +/************************************** + Local Structures and types +**************************************/ +typedef struct { + U32 hashTable[HASH_SIZE_U32]; + U32 currentOffset; + U32 initCheck; + const BYTE* dictionary; + const BYTE* bufferStart; + U32 dictSize; +} LZ4_stream_t_internal; + +typedef enum { notLimited = 0, limitedOutput = 1 } limitedOutput_directive; +typedef enum { byPtr, byU32, byU16 } tableType_t; + +typedef enum { noDict = 0, withPrefix64k, usingExtDict } dict_directive; +typedef enum { noDictIssue = 0, dictSmall } dictIssue_directive; + +typedef enum { endOnOutputSize = 0, endOnInputSize = 1 } endCondition_directive; +typedef enum { full = 0, partial = 1 } earlyEnd_directive; + + + +/******************************** + Compression functions +********************************/ + +static U32 LZ4_hashSequence(U32 sequence, tableType_t tableType) +{ + if (tableType == byU16) + return (((sequence) * 2654435761U) >> ((MINMATCH*8)-(LZ4_HASHLOG+1))); + else + return (((sequence) * 2654435761U) >> ((MINMATCH*8)-LZ4_HASHLOG)); +} + +static U32 LZ4_hashPosition(const BYTE* p, tableType_t tableType) { return LZ4_hashSequence(LZ4_read32(p), tableType); } + +static void LZ4_putPositionOnHash(const BYTE* p, U32 h, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + switch (tableType) + { + case byPtr: { const BYTE** hashTable = (const BYTE**)tableBase; hashTable[h] = p; return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = (U32)(p-srcBase); return; } + case byU16: { U16* hashTable = (U16*) tableBase; hashTable[h] = (U16)(p-srcBase); return; } + } +} + +static void LZ4_putPosition(const BYTE* p, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + U32 h = LZ4_hashPosition(p, tableType); + LZ4_putPositionOnHash(p, h, tableBase, tableType, srcBase); +} + +static const BYTE* LZ4_getPositionOnHash(U32 h, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + if (tableType == byPtr) { const BYTE** hashTable = (const BYTE**) tableBase; return hashTable[h]; } + if (tableType == byU32) { U32* hashTable = (U32*) tableBase; return hashTable[h] + srcBase; } + { U16* hashTable = (U16*) tableBase; return hashTable[h] + srcBase; } /* default, to ensure a return */ +} + +static const BYTE* LZ4_getPosition(const BYTE* p, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + U32 h = LZ4_hashPosition(p, tableType); + return LZ4_getPositionOnHash(h, tableBase, tableType, srcBase); +} + +static int LZ4_compress_generic( + void* ctx, + const char* source, + char* dest, + int inputSize, + int maxOutputSize, + limitedOutput_directive outputLimited, + tableType_t tableType, + dict_directive dict, + dictIssue_directive dictIssue) +{ + LZ4_stream_t_internal* const dictPtr = (LZ4_stream_t_internal*)ctx; + + const BYTE* ip = (const BYTE*) source; + const BYTE* base; + const BYTE* lowLimit; + const 
BYTE* const lowRefLimit = ip - dictPtr->dictSize; + const BYTE* const dictionary = dictPtr->dictionary; + const BYTE* const dictEnd = dictionary + dictPtr->dictSize; + const size_t dictDelta = dictEnd - (const BYTE*)source; + const BYTE* anchor = (const BYTE*) source; + const BYTE* const iend = ip + inputSize; + const BYTE* const mflimit = iend - MFLIMIT; + const BYTE* const matchlimit = iend - LASTLITERALS; + + BYTE* op = (BYTE*) dest; + BYTE* const olimit = op + maxOutputSize; + + U32 forwardH; + size_t refDelta=0; + + /* Init conditions */ + if ((U32)inputSize > (U32)LZ4_MAX_INPUT_SIZE) return 0; /* Unsupported input size, too large (or negative) */ + switch(dict) + { + case noDict: + default: + base = (const BYTE*)source; + lowLimit = (const BYTE*)source; + break; + case withPrefix64k: + base = (const BYTE*)source - dictPtr->currentOffset; + lowLimit = (const BYTE*)source - dictPtr->dictSize; + break; + case usingExtDict: + base = (const BYTE*)source - dictPtr->currentOffset; + lowLimit = (const BYTE*)source; + break; + } + if ((tableType == byU16) && (inputSize>=LZ4_64Klimit)) return 0; /* Size too large (not within 64K limit) */ + if (inputSize> LZ4_skipTrigger; + + if (unlikely(forwardIp > mflimit)) goto _last_literals; + + match = LZ4_getPositionOnHash(h, ctx, tableType, base); + if (dict==usingExtDict) + { + if (match<(const BYTE*)source) + { + refDelta = dictDelta; + lowLimit = dictionary; + } + else + { + refDelta = 0; + lowLimit = (const BYTE*)source; + } + } + forwardH = LZ4_hashPosition(forwardIp, tableType); + LZ4_putPositionOnHash(ip, h, ctx, tableType, base); + + } while ( ((dictIssue==dictSmall) ? (match < lowRefLimit) : 0) + || ((tableType==byU16) ? 0 : (match + MAX_DISTANCE < ip)) + || (LZ4_read32(match+refDelta) != LZ4_read32(ip)) ); + } + + /* Catch up */ + while ((ip>anchor) && (match+refDelta > lowLimit) && (unlikely(ip[-1]==match[refDelta-1]))) { ip--; match--; } + + { + /* Encode Literal length */ + unsigned litLength = (unsigned)(ip - anchor); + token = op++; + if ((outputLimited) && (unlikely(op + litLength + (2 + 1 + LASTLITERALS) + (litLength/255) > olimit))) + return 0; /* Check output limit */ + if (litLength>=RUN_MASK) + { + int len = (int)litLength-RUN_MASK; + *token=(RUN_MASK<= 255 ; len-=255) *op++ = 255; + *op++ = (BYTE)len; + } + else *token = (BYTE)(litLength< matchlimit) limit = matchlimit; + matchLength = LZ4_count(ip+MINMATCH, match+MINMATCH, limit); + ip += MINMATCH + matchLength; + if (ip==limit) + { + unsigned more = LZ4_count(ip, (const BYTE*)source, matchlimit); + matchLength += more; + ip += more; + } + } + else + { + matchLength = LZ4_count(ip+MINMATCH, match+MINMATCH, matchlimit); + ip += MINMATCH + matchLength; + } + + if ((outputLimited) && (unlikely(op + (1 + LASTLITERALS) + (matchLength>>8) > olimit))) + return 0; /* Check output limit */ + if (matchLength>=ML_MASK) + { + *token += ML_MASK; + matchLength -= ML_MASK; + for (; matchLength >= 510 ; matchLength-=510) { *op++ = 255; *op++ = 255; } + if (matchLength >= 255) { matchLength-=255; *op++ = 255; } + *op++ = (BYTE)matchLength; + } + else *token += (BYTE)(matchLength); + } + + anchor = ip; + + /* Test end of chunk */ + if (ip > mflimit) break; + + /* Fill table */ + LZ4_putPosition(ip-2, ctx, tableType, base); + + /* Test next position */ + match = LZ4_getPosition(ip, ctx, tableType, base); + if (dict==usingExtDict) + { + if (match<(const BYTE*)source) + { + refDelta = dictDelta; + lowLimit = dictionary; + } + else + { + refDelta = 0; + lowLimit = (const BYTE*)source; + } + } + 
LZ4_putPosition(ip, ctx, tableType, base); + if ( ((dictIssue==dictSmall) ? (match>=lowRefLimit) : 1) + && (match+MAX_DISTANCE>=ip) + && (LZ4_read32(match+refDelta)==LZ4_read32(ip)) ) + { token=op++; *token=0; goto _next_match; } + + /* Prepare next loop */ + forwardH = LZ4_hashPosition(++ip, tableType); + } + +_last_literals: + /* Encode Last Literals */ + { + int lastRun = (int)(iend - anchor); + if ((outputLimited) && (((char*)op - dest) + lastRun + 1 + ((lastRun+255-RUN_MASK)/255) > (U32)maxOutputSize)) + return 0; /* Check output limit */ + if (lastRun>=(int)RUN_MASK) { *op++=(RUN_MASK<= 255 ; lastRun-=255) *op++ = 255; *op++ = (BYTE) lastRun; } + else *op++ = (BYTE)(lastRun<= sizeof(LZ4_stream_t_internal)); /* A compilation error here means LZ4_STREAMSIZE is not large enough */ + LZ4_resetStream(lz4s); + return lz4s; +} + +int LZ4_freeStream (LZ4_stream_t* LZ4_stream) +{ + FREEMEM(LZ4_stream); + return (0); +} + + +int LZ4_loadDict (LZ4_stream_t* LZ4_dict, const char* dictionary, int dictSize) +{ + LZ4_stream_t_internal* dict = (LZ4_stream_t_internal*) LZ4_dict; + const BYTE* p = (const BYTE*)dictionary; + const BYTE* const dictEnd = p + dictSize; + const BYTE* base; + + if (dict->initCheck) LZ4_resetStream(LZ4_dict); /* Uninitialized structure detected */ + + if (dictSize < MINMATCH) + { + dict->dictionary = NULL; + dict->dictSize = 0; + return 0; + } + + if (p <= dictEnd - 64 KB) p = dictEnd - 64 KB; + base = p - dict->currentOffset; + dict->dictionary = p; + dict->dictSize = (U32)(dictEnd - p); + dict->currentOffset += dict->dictSize; + + while (p <= dictEnd-MINMATCH) + { + LZ4_putPosition(p, dict, byU32, base); + p+=3; + } + + return dict->dictSize; +} + + +static void LZ4_renormDictT(LZ4_stream_t_internal* LZ4_dict, const BYTE* src) +{ + if ((LZ4_dict->currentOffset > 0x80000000) || + ((size_t)LZ4_dict->currentOffset > (size_t)src)) /* address space overflow */ + { + /* rescale hash table */ + U32 delta = LZ4_dict->currentOffset - 64 KB; + const BYTE* dictEnd = LZ4_dict->dictionary + LZ4_dict->dictSize; + int i; + for (i=0; ihashTable[i] < delta) LZ4_dict->hashTable[i]=0; + else LZ4_dict->hashTable[i] -= delta; + } + LZ4_dict->currentOffset = 64 KB; + if (LZ4_dict->dictSize > 64 KB) LZ4_dict->dictSize = 64 KB; + LZ4_dict->dictionary = dictEnd - LZ4_dict->dictSize; + } +} + + +FORCE_INLINE int LZ4_compress_continue_generic (void* LZ4_stream, const char* source, char* dest, int inputSize, + int maxOutputSize, limitedOutput_directive limit) +{ + LZ4_stream_t_internal* streamPtr = (LZ4_stream_t_internal*)LZ4_stream; + const BYTE* const dictEnd = streamPtr->dictionary + streamPtr->dictSize; + + const BYTE* smallest = (const BYTE*) source; + if (streamPtr->initCheck) return 0; /* Uninitialized structure detected */ + if ((streamPtr->dictSize>0) && (smallest>dictEnd)) smallest = dictEnd; + LZ4_renormDictT(streamPtr, smallest); + + /* Check overlapping input/dictionary space */ + { + const BYTE* sourceEnd = (const BYTE*) source + inputSize; + if ((sourceEnd > streamPtr->dictionary) && (sourceEnd < dictEnd)) + { + streamPtr->dictSize = (U32)(dictEnd - sourceEnd); + if (streamPtr->dictSize > 64 KB) streamPtr->dictSize = 64 KB; + if (streamPtr->dictSize < 4) streamPtr->dictSize = 0; + streamPtr->dictionary = dictEnd - streamPtr->dictSize; + } + } + + /* prefix mode : source data follows dictionary */ + if (dictEnd == (const BYTE*)source) + { + int result; + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) + result = LZ4_compress_generic(LZ4_stream, 
source, dest, inputSize, maxOutputSize, limit, byU32, withPrefix64k, dictSmall); + else + result = LZ4_compress_generic(LZ4_stream, source, dest, inputSize, maxOutputSize, limit, byU32, withPrefix64k, noDictIssue); + streamPtr->dictSize += (U32)inputSize; + streamPtr->currentOffset += (U32)inputSize; + return result; + } + + /* external dictionary mode */ + { + int result; + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) + result = LZ4_compress_generic(LZ4_stream, source, dest, inputSize, maxOutputSize, limit, byU32, usingExtDict, dictSmall); + else + result = LZ4_compress_generic(LZ4_stream, source, dest, inputSize, maxOutputSize, limit, byU32, usingExtDict, noDictIssue); + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)inputSize; + streamPtr->currentOffset += (U32)inputSize; + return result; + } +} + +int LZ4_compress_continue (LZ4_stream_t* LZ4_stream, const char* source, char* dest, int inputSize) +{ + return LZ4_compress_continue_generic(LZ4_stream, source, dest, inputSize, 0, notLimited); +} + +int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_stream, const char* source, char* dest, int inputSize, int maxOutputSize) +{ + return LZ4_compress_continue_generic(LZ4_stream, source, dest, inputSize, maxOutputSize, limitedOutput); +} + + +/* Hidden debug function, to force separate dictionary mode */ +int LZ4_compress_forceExtDict (LZ4_stream_t* LZ4_dict, const char* source, char* dest, int inputSize) +{ + LZ4_stream_t_internal* streamPtr = (LZ4_stream_t_internal*)LZ4_dict; + int result; + const BYTE* const dictEnd = streamPtr->dictionary + streamPtr->dictSize; + + const BYTE* smallest = dictEnd; + if (smallest > (const BYTE*) source) smallest = (const BYTE*) source; + LZ4_renormDictT((LZ4_stream_t_internal*)LZ4_dict, smallest); + + result = LZ4_compress_generic(LZ4_dict, source, dest, inputSize, 0, notLimited, byU32, usingExtDict, noDictIssue); + + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)inputSize; + streamPtr->currentOffset += (U32)inputSize; + + return result; +} + + +int LZ4_saveDict (LZ4_stream_t* LZ4_dict, char* safeBuffer, int dictSize) +{ + LZ4_stream_t_internal* dict = (LZ4_stream_t_internal*) LZ4_dict; + const BYTE* previousDictEnd = dict->dictionary + dict->dictSize; + + if ((U32)dictSize > 64 KB) dictSize = 64 KB; /* useless to define a dictionary > 64 KB */ + if ((U32)dictSize > dict->dictSize) dictSize = dict->dictSize; + + memmove(safeBuffer, previousDictEnd - dictSize, dictSize); + + dict->dictionary = (const BYTE*)safeBuffer; + dict->dictSize = (U32)dictSize; + + return dictSize; +} + + + +/**************************** + Decompression functions +****************************/ +/* + * This generic decompression function cover all use cases. + * It shall be instantiated several times, using different sets of directives + * Note that it is essential this generic function is really inlined, + * in order to remove useless branches during compilation optimization. + */ +FORCE_INLINE int LZ4_decompress_generic( + const char* const source, + char* const dest, + int inputSize, + int outputSize, /* If endOnInput==endOnInputSize, this value is the max size of Output Buffer. 
*/ + + int endOnInput, /* endOnOutputSize, endOnInputSize */ + int partialDecoding, /* full, partial */ + int targetOutputSize, /* only used if partialDecoding==partial */ + int dict, /* noDict, withPrefix64k, usingExtDict */ + const BYTE* const lowPrefix, /* == dest if dict == noDict */ + const BYTE* const dictStart, /* only if dict==usingExtDict */ + const size_t dictSize /* note : = 0 if noDict */ + ) +{ + /* Local Variables */ + const BYTE* restrict ip = (const BYTE*) source; + const BYTE* const iend = ip + inputSize; + + BYTE* op = (BYTE*) dest; + BYTE* const oend = op + outputSize; + BYTE* cpy; + BYTE* oexit = op + targetOutputSize; + const BYTE* const lowLimit = lowPrefix - dictSize; + + const BYTE* const dictEnd = (const BYTE*)dictStart + dictSize; + const size_t dec32table[] = {4, 1, 2, 1, 4, 4, 4, 4}; + const size_t dec64table[] = {0, 0, 0, (size_t)-1, 0, 1, 2, 3}; + + const int safeDecode = (endOnInput==endOnInputSize); + const int checkOffset = ((safeDecode) && (dictSize < (int)(64 KB))); + + + /* Special cases */ + if ((partialDecoding) && (oexit> oend-MFLIMIT)) oexit = oend-MFLIMIT; /* targetOutputSize too high => decode everything */ + if ((endOnInput) && (unlikely(outputSize==0))) return ((inputSize==1) && (*ip==0)) ? 0 : -1; /* Empty output buffer */ + if ((!endOnInput) && (unlikely(outputSize==0))) return (*ip==0?1:-1); + + + /* Main Loop */ + while (1) + { + unsigned token; + size_t length; + const BYTE* match; + + /* get literal length */ + token = *ip++; + if ((length=(token>>ML_BITS)) == RUN_MASK) + { + unsigned s; + do + { + s = *ip++; + length += s; + } + while (likely((endOnInput)?ip(partialDecoding?oexit:oend-MFLIMIT)) || (ip+length>iend-(2+1+LASTLITERALS))) ) + || ((!endOnInput) && (cpy>oend-COPYLENGTH))) + { + if (partialDecoding) + { + if (cpy > oend) goto _output_error; /* Error : write attempt beyond end of output buffer */ + if ((endOnInput) && (ip+length > iend)) goto _output_error; /* Error : read attempt beyond end of input buffer */ + } + else + { + if ((!endOnInput) && (cpy != oend)) goto _output_error; /* Error : block decoding must stop exactly there */ + if ((endOnInput) && ((ip+length != iend) || (cpy > oend))) goto _output_error; /* Error : input must be consumed */ + } + memcpy(op, ip, length); + ip += length; + op += length; + break; /* Necessarily EOF, due to parsing restrictions */ + } + LZ4_wildCopy(op, ip, cpy); + ip += length; op = cpy; + + /* get offset */ + match = cpy - LZ4_readLE16(ip); ip+=2; + if ((checkOffset) && (unlikely(match < lowLimit))) goto _output_error; /* Error : offset outside destination buffer */ + + /* get matchlength */ + length = token & ML_MASK; + if (length == ML_MASK) + { + unsigned s; + do + { + if ((endOnInput) && (ip > iend-LASTLITERALS)) goto _output_error; + s = *ip++; + length += s; + } while (s==255); + if ((safeDecode) && unlikely((size_t)(op+length)<(size_t)op)) goto _output_error; /* overflow detection */ + } + length += MINMATCH; + + /* check external dictionary */ + if ((dict==usingExtDict) && (match < lowPrefix)) + { + if (unlikely(op+length > oend-LASTLITERALS)) goto _output_error; /* doesn't respect parsing restriction */ + + if (length <= (size_t)(lowPrefix-match)) + { + /* match can be copied as a single segment from external dictionary */ + match = dictEnd - (lowPrefix-match); + memcpy(op, match, length); + op += length; + } + else + { + /* match encompass external dictionary and current segment */ + size_t copySize = (size_t)(lowPrefix-match); + memcpy(op, dictEnd - copySize, copySize); + op += 
copySize; + copySize = length - copySize; + if (copySize > (size_t)(op-lowPrefix)) /* overlap within current segment */ + { + BYTE* const endOfMatch = op + copySize; + const BYTE* copyFrom = lowPrefix; + while (op < endOfMatch) *op++ = *copyFrom++; + } + else + { + memcpy(op, lowPrefix, copySize); + op += copySize; + } + } + continue; + } + + /* copy repeated sequence */ + cpy = op + length; + if (unlikely((op-match)<8)) + { + const size_t dec64 = dec64table[op-match]; + op[0] = match[0]; + op[1] = match[1]; + op[2] = match[2]; + op[3] = match[3]; + match += dec32table[op-match]; + LZ4_copy4(op+4, match); + op += 8; match -= dec64; + } else { LZ4_copy8(op, match); op+=8; match+=8; } + + if (unlikely(cpy>oend-12)) + { + if (cpy > oend-LASTLITERALS) goto _output_error; /* Error : last LASTLITERALS bytes must be literals */ + if (op < oend-8) + { + LZ4_wildCopy(op, match, oend-8); + match += (oend-8) - op; + op = oend-8; + } + while (opprefixSize = (size_t) dictSize; + lz4sd->prefixEnd = (BYTE*) dictionary + dictSize; + lz4sd->externalDict = NULL; + lz4sd->extDictSize = 0; + return 1; +} + +/* +*_continue() : + These decoding functions allow decompression of multiple blocks in "streaming" mode. + Previously decoded blocks must still be available at the memory position where they were decoded. + If it's not possible, save the relevant part of decoded data into a safe buffer, + and indicate where it stands using LZ4_setStreamDecode() +*/ +int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int compressedSize, int maxOutputSize) +{ + LZ4_streamDecode_t_internal* lz4sd = (LZ4_streamDecode_t_internal*) LZ4_streamDecode; + int result; + + if (lz4sd->prefixEnd == (BYTE*)dest) + { + result = LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, full, 0, + usingExtDict, lz4sd->prefixEnd - lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += result; + lz4sd->prefixEnd += result; + } + else + { + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = lz4sd->prefixEnd - lz4sd->extDictSize; + result = LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, full, 0, + usingExtDict, (BYTE*)dest, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = result; + lz4sd->prefixEnd = (BYTE*)dest + result; + } + + return result; +} + +int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int originalSize) +{ + LZ4_streamDecode_t_internal* lz4sd = (LZ4_streamDecode_t_internal*) LZ4_streamDecode; + int result; + + if (lz4sd->prefixEnd == (BYTE*)dest) + { + result = LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, full, 0, + usingExtDict, lz4sd->prefixEnd - lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += originalSize; + lz4sd->prefixEnd += originalSize; + } + else + { + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = (BYTE*)dest - lz4sd->extDictSize; + result = LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, full, 0, + usingExtDict, (BYTE*)dest, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = originalSize; + lz4sd->prefixEnd = (BYTE*)dest + originalSize; + } + + return result; +} + + +/* +Advanced decoding functions : +*_usingDict() : + These decoding 
functions work the same as "_continue" ones, + the dictionary must be explicitly provided within parameters +*/ + +FORCE_INLINE int LZ4_decompress_usingDict_generic(const char* source, char* dest, int compressedSize, int maxOutputSize, int safe, const char* dictStart, int dictSize) +{ + if (dictSize==0) + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, safe, full, 0, noDict, (BYTE*)dest, NULL, 0); + if (dictStart+dictSize == dest) + { + if (dictSize >= (int)(64 KB - 1)) + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, safe, full, 0, withPrefix64k, (BYTE*)dest-64 KB, NULL, 0); + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, safe, full, 0, noDict, (BYTE*)dest-dictSize, NULL, 0); + } + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, safe, full, 0, usingExtDict, (BYTE*)dest, (BYTE*)dictStart, dictSize); +} + +int LZ4_decompress_safe_usingDict(const char* source, char* dest, int compressedSize, int maxOutputSize, const char* dictStart, int dictSize) +{ + return LZ4_decompress_usingDict_generic(source, dest, compressedSize, maxOutputSize, 1, dictStart, dictSize); +} + +int LZ4_decompress_fast_usingDict(const char* source, char* dest, int originalSize, const char* dictStart, int dictSize) +{ + return LZ4_decompress_usingDict_generic(source, dest, 0, originalSize, 0, dictStart, dictSize); +} + +/* debug function */ +int LZ4_decompress_safe_forceExtDict(const char* source, char* dest, int compressedSize, int maxOutputSize, const char* dictStart, int dictSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, endOnInputSize, full, 0, usingExtDict, (BYTE*)dest, (BYTE*)dictStart, dictSize); +} + + +/*************************************************** + Obsolete Functions +***************************************************/ +/* +These function names are deprecated and should no longer be used. +They are only provided here for compatibility with older user programs. 
+- LZ4_uncompress is totally equivalent to LZ4_decompress_fast
+- LZ4_uncompress_unknownOutputSize is totally equivalent to LZ4_decompress_safe
+*/
+int LZ4_uncompress (const char* source, char* dest, int outputSize) { return LZ4_decompress_fast(source, dest, outputSize); }
+int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize) { return LZ4_decompress_safe(source, dest, isize, maxOutputSize); }
+
+
+/* Obsolete Streaming functions */
+
+int LZ4_sizeofStreamState() { return LZ4_STREAMSIZE; }
+
+static void LZ4_init(LZ4_stream_t_internal* lz4ds, const BYTE* base)
+{
+    MEM_INIT(lz4ds, 0, LZ4_STREAMSIZE);
+    lz4ds->bufferStart = base;
+}
+
+int LZ4_resetStreamState(void* state, const char* inputBuffer)
+{
+    if ((((size_t)state) & 3) != 0) return 1;   /* Error : pointer is not aligned on 4-byte boundary */
+    LZ4_init((LZ4_stream_t_internal*)state, (const BYTE*)inputBuffer);
+    return 0;
+}
+
+void* LZ4_create (const char* inputBuffer)
+{
+    void* lz4ds = ALLOCATOR(8, LZ4_STREAMSIZE_U64);
+    LZ4_init ((LZ4_stream_t_internal*)lz4ds, (const BYTE*)inputBuffer);
+    return lz4ds;
+}
+
+char* LZ4_slideInputBuffer (void* LZ4_Data)
+{
+    LZ4_stream_t_internal* ctx = (LZ4_stream_t_internal*)LZ4_Data;
+    int dictSize = LZ4_saveDict((LZ4_stream_t*)ctx, (char*)ctx->bufferStart, 64 KB);
+    return (char*)(ctx->bufferStart + dictSize);
+}
+
+/* Obsolete compression functions using User-allocated state */
+
+int LZ4_sizeofState() { return LZ4_STREAMSIZE; }
+
+int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize)
+{
+    if (((size_t)(state)&3) != 0) return 0;   /* Error : state is not aligned on 4-byte boundary */
+    MEM_INIT(state, 0, LZ4_STREAMSIZE);
+
+    if (inputSize < LZ4_64Klimit)
+        return LZ4_compress_generic(state, source, dest, inputSize, 0, notLimited, byU16, noDict, noDictIssue);
+    else
+        return LZ4_compress_generic(state, source, dest, inputSize, 0, notLimited, LZ4_64bits() ? byU32 : byPtr, noDict, noDictIssue);
+}
+
+int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize)
+{
+    if (((size_t)(state)&3) != 0) return 0;   /* Error : state is not aligned on 4-byte boundary */
+    MEM_INIT(state, 0, LZ4_STREAMSIZE);
+
+    if (inputSize < LZ4_64Klimit)
+        return LZ4_compress_generic(state, source, dest, inputSize, maxOutputSize, limitedOutput, byU16, noDict, noDictIssue);
+    else
+        return LZ4_compress_generic(state, source, dest, inputSize, maxOutputSize, limitedOutput, LZ4_64bits() ? byU32 : byPtr, noDict, noDictIssue);
+}
+
+/* Obsolete streaming decompression functions */
+
+int LZ4_decompress_safe_withPrefix64k(const char* source, char* dest, int compressedSize, int maxOutputSize)
+{
+    return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, endOnInputSize, full, 0, withPrefix64k, (BYTE*)dest - 64 KB, NULL, 64 KB);
+}
+
+int LZ4_decompress_fast_withPrefix64k(const char* source, char* dest, int originalSize)
+{
+    return LZ4_decompress_generic(source, dest, 0, originalSize, endOnOutputSize, full, 0, withPrefix64k, (BYTE*)dest - 64 KB, NULL, 64 KB);
+}
+
+#endif   /* LZ4_COMMONDEFS_ONLY */
+
diff --git a/vendor/github.com/cloudflare/golz4/src/lz4.h b/vendor/github.com/cloudflare/golz4/src/lz4.h
new file mode 100644
index 00000000..7778caad
--- /dev/null
+++ b/vendor/github.com/cloudflare/golz4/src/lz4.h
@@ -0,0 +1,315 @@
+/*
+   LZ4 - Fast LZ compression algorithm
+   Header File
+   Copyright (C) 2011-2014, Yann Collet.
+ BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 source repository : http://code.google.com/p/lz4/ + - LZ4 public forum : https://groups.google.com/forum/#!forum/lz4c +*/ +#pragma once + +#if defined (__cplusplus) +extern "C" { +#endif + +/* + * lz4.h provides raw compression format functions, for optimal performance and integration into programs. + * If you need to generate data using an inter-operable format (respecting the framing specification), + * please use lz4frame.h instead. +*/ + +/************************************** + Version +**************************************/ +#define LZ4_VERSION_MAJOR 1 /* for breaking interface changes */ +#define LZ4_VERSION_MINOR 5 /* for new (non-breaking) interface capabilities */ +#define LZ4_VERSION_RELEASE 0 /* for tweaks, bug-fixes, or development */ +#define LZ4_VERSION_NUMBER (LZ4_VERSION_MAJOR *100*100 + LZ4_VERSION_MINOR *100 + LZ4_VERSION_RELEASE) +int LZ4_versionNumber (void); + +/************************************** + Tuning parameter +**************************************/ +/* + * LZ4_MEMORY_USAGE : + * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) + * Increasing memory usage improves compression ratio + * Reduced memory usage can improve speed, due to cache effect + * Default value is 14, for 16KB, which nicely fits into Intel x86 L1 cache + */ +#define LZ4_MEMORY_USAGE 14 + + +/************************************** + Simple Functions +**************************************/ + +int LZ4_compress (const char* source, char* dest, int sourceSize); +int LZ4_decompress_safe (const char* source, char* dest, int compressedSize, int maxDecompressedSize); + +/* +LZ4_compress() : + Compresses 'sourceSize' bytes from 'source' into 'dest'. 
+  Destination buffer must be already allocated,
+  and must be sized to handle worst case situations (input data not compressible)
+  Worst case size evaluation is provided by function LZ4_compressBound()
+  inputSize : Max supported value is LZ4_MAX_INPUT_SIZE
+  return : the number of bytes written in buffer dest
+           or 0 if the compression fails
+
+LZ4_decompress_safe() :
+  compressedSize : is obviously the source size
+  maxDecompressedSize : is the size of the destination buffer, which must be already allocated.
+  return : the number of bytes decompressed into the destination buffer (necessarily <= maxDecompressedSize)
+           If the destination buffer is not large enough, decoding will stop and output an error code (<0).
+           If the source stream is detected malformed, the function will stop decoding and return a negative result.
+           This function is protected against buffer overflow exploits,
+           and never writes outside of output buffer, nor reads outside of input buffer.
+           It is also protected against malicious data packets.
+*/
+
+
+/**************************************
+   Advanced Functions
+**************************************/
+#define LZ4_MAX_INPUT_SIZE        0x7E000000   /* 2 113 929 216 bytes */
+#define LZ4_COMPRESSBOUND(isize)  ((unsigned int)(isize) > (unsigned int)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16)
+
+/*
+LZ4_compressBound() :
+    Provides the maximum size that LZ4 compression may output in a "worst case" scenario (input data not compressible)
+    This function is primarily useful for memory allocation purposes (output buffer size).
+    Macro LZ4_COMPRESSBOUND() is also provided for compilation-time evaluation (stack memory allocation for example).
+
+    isize  : is the input size. Max supported value is LZ4_MAX_INPUT_SIZE
+    return : maximum output size in a "worst case" scenario
+             or 0, if input size is too large ( > LZ4_MAX_INPUT_SIZE)
+*/
+int LZ4_compressBound(int isize);
+
+
+/*
+LZ4_compress_limitedOutput() :
+    Compress 'sourceSize' bytes from 'source' into an output buffer 'dest' of maximum size 'maxOutputSize'.
+    If it cannot achieve it, compression will stop, and result of the function will be zero.
+    This saves time and memory on detecting non-compressible (or barely compressible) data.
+    This function never writes outside of provided output buffer.
+
+    sourceSize    : Max supported value is LZ4_MAX_INPUT_SIZE
+    maxOutputSize : is the size of the destination buffer (which must be already allocated)
+    return : the number of bytes written in buffer 'dest'
+             or 0 if compression fails
+*/
+int LZ4_compress_limitedOutput (const char* source, char* dest, int sourceSize, int maxOutputSize);
+
+
+/*
+LZ4_compress_withState() :
+    Same compression functions, but using an externally allocated memory space to store compression state.
+    Use LZ4_sizeofState() to know how much memory must be allocated,
+    and then, provide it as 'void* state' to compression functions.
+*/
+int LZ4_sizeofState(void);
+int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize);
+int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize);
+
+
+/*
+LZ4_decompress_fast() :
+    originalSize : is the original and therefore uncompressed size
+    return : the number of bytes read from the source buffer (in other words, the compressed size)
+             If the source stream is detected malformed, the function will stop decoding and return a negative result.
+             Destination buffer must be already allocated. Its size must be a minimum of 'originalSize' bytes.
+    note : This function fully respects memory boundaries for properly formed compressed data.
+           It is a bit faster than LZ4_decompress_safe().
+           However, it does not provide any protection against intentionally modified data stream (malicious input).
+           Use this function in trusted environment only (data to decode comes from a trusted source).
+*/
+int LZ4_decompress_fast (const char* source, char* dest, int originalSize);
+
+
+/*
+LZ4_decompress_safe_partial() :
+    This function decompresses a compressed block of size 'compressedSize' at position 'source'
+    into destination buffer 'dest' of size 'maxDecompressedSize'.
+    The function tries to stop the decompressing operation as soon as 'targetOutputSize' has been reached,
+    reducing decompression time.
+    return : the number of bytes decoded in the destination buffer (necessarily <= maxDecompressedSize)
+       Note : this number can be < 'targetOutputSize' should the compressed block to decode be smaller.
+             Always control how many bytes were decoded.
+             If the source stream is detected malformed, the function will stop decoding and return a negative result.
+             This function never writes outside of output buffer, and never reads outside of input buffer. It is therefore protected against malicious data packets
+*/
+int LZ4_decompress_safe_partial (const char* source, char* dest, int compressedSize, int targetOutputSize, int maxDecompressedSize);
+
+
+/***********************************************
+   Streaming Compression Functions
+***********************************************/
+
+#define LZ4_STREAMSIZE_U64 ((1 << (LZ4_MEMORY_USAGE-3)) + 4)
+#define LZ4_STREAMSIZE     (LZ4_STREAMSIZE_U64 * sizeof(long long))
+/*
+ * LZ4_stream_t
+ * information structure to track an LZ4 stream.
+ * important : init this structure content before first use !
+ * note : only allocate the structure directly if you are statically linking LZ4
+ *        If you are using liblz4 as a DLL, please use below construction methods instead.
+ */
+typedef struct { long long table[LZ4_STREAMSIZE_U64]; } LZ4_stream_t;
+
+/*
+ * LZ4_resetStream
+ * Use this function to init an allocated LZ4_stream_t structure
+ */
+void LZ4_resetStream (LZ4_stream_t* LZ4_streamPtr);
+
+/*
+ * LZ4_createStream will allocate and initialize an LZ4_stream_t structure
+ * LZ4_freeStream releases its memory.
+ * In the context of a DLL (liblz4), please use these methods rather than the static struct.
+ * They are more future proof, in case of a change of LZ4_stream_t size.
+ */
+LZ4_stream_t* LZ4_createStream(void);
+int           LZ4_freeStream (LZ4_stream_t* LZ4_streamPtr);
+
+/*
+ * LZ4_loadDict
+ * Use this function to load a static dictionary into LZ4_stream.
+ * Any previous data will be forgotten, only 'dictionary' will remain in memory.
+ * Loading a size of 0 is allowed.
+ * Return : dictionary size, in bytes (necessarily <= 64 KB)
+ */
+int LZ4_loadDict (LZ4_stream_t* LZ4_streamPtr, const char* dictionary, int dictSize);
+
+/*
+ * LZ4_compress_continue
+ * Compress data block 'source', using blocks compressed before as dictionary to improve compression ratio
+ * Previous data blocks are assumed to still be present at their previous location.
+ */
+int LZ4_compress_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize);
+
+/*
+ * LZ4_compress_limitedOutput_continue
+ * Same as before, but also specify a maximum target compressed size (maxOutputSize)
+ * If the objective cannot be met, compression exits, and returns a zero.
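+ *
+ * A minimal usage sketch (illustrative only : CHUNK_SIZE, read_chunk() and
+ * write_block() are placeholders, not part of this API) :
+ *
+ *     LZ4_stream_t ls;
+ *     LZ4_resetStream(&ls);
+ *     char dst[LZ4_COMPRESSBOUND(CHUNK_SIZE)];
+ *     char src[2][CHUNK_SIZE];
+ *     int  idx = 0;
+ *     while (read_chunk(src[idx], CHUNK_SIZE))
+ *     {
+ *         int csize = LZ4_compress_limitedOutput_continue(&ls, src[idx], dst, CHUNK_SIZE, (int)sizeof(dst));
+ *         if (csize == 0) break;    (objective not met)
+ *         write_block(dst, csize);
+ *         idx ^= 1;                 (previous block must remain readable)
+ *     }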
+ */
+int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize, int maxOutputSize);
+
+/*
+ * LZ4_saveDict
+ * If the previously compressed data block is not guaranteed to remain available at its memory location,
+ * save it into a safer place (char* safeBuffer)
+ * Note : you don't need to call LZ4_loadDict() afterwards,
+ *        dictionary is immediately usable, you can therefore call LZ4_compress_continue() again
+ * Return : dictionary size in bytes, or 0 if error
+ * Note : any dictSize > 64 KB will be interpreted as 64KB.
+ */
+int LZ4_saveDict (LZ4_stream_t* LZ4_streamPtr, char* safeBuffer, int dictSize);
+
+
+/************************************************
+   Streaming Decompression Functions
+************************************************/
+
+#define LZ4_STREAMDECODESIZE_U64  4
+#define LZ4_STREAMDECODESIZE     (LZ4_STREAMDECODESIZE_U64 * sizeof(unsigned long long))
+typedef struct { unsigned long long table[LZ4_STREAMDECODESIZE_U64]; } LZ4_streamDecode_t;
+/*
+ * LZ4_streamDecode_t
+ * information structure to track an LZ4 stream.
+ * init this structure content using LZ4_setStreamDecode or memset() before first use !
+ *
+ * In the context of a DLL (liblz4) please prefer usage of construction methods below.
+ * They are more future proof, in case of a change of LZ4_streamDecode_t size in the future.
+ * LZ4_createStreamDecode will allocate and initialize an LZ4_streamDecode_t structure
+ * LZ4_freeStreamDecode releases its memory.
+ */
+LZ4_streamDecode_t* LZ4_createStreamDecode(void);
+int                 LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream);
+
+/*
+ * LZ4_setStreamDecode
+ * Use this function to instruct where to find the dictionary.
+ * Setting a size of 0 is allowed (same effect as reset).
+ * Return : 1 if OK, 0 if error
+ */
+int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize);
+
+/*
+*_continue() :
+    These decoding functions allow decompression of multiple blocks in "streaming" mode.
+    Previously decoded blocks *must* remain available at the memory position where they were decoded (up to 64 KB)
+    If this condition is not possible, save the relevant part of decoded data into a safe buffer,
+    and indicate its new address using LZ4_setStreamDecode()
+*/
+int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int compressedSize, int maxDecompressedSize);
+int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int originalSize);
+
+
+/*
+Advanced decoding functions :
+*_usingDict() :
+    These decoding functions work the same as
+    a combination of LZ4_setStreamDecode() followed by LZ4_decompress_x_continue()
+    They are stand-alone, and don't use or update an LZ4_streamDecode_t structure.
+*/
+int LZ4_decompress_safe_usingDict (const char* source, char* dest, int compressedSize, int maxDecompressedSize, const char* dictStart, int dictSize);
+int LZ4_decompress_fast_usingDict (const char* source, char* dest, int originalSize, const char* dictStart, int dictSize);
+
+
+
+/**************************************
+   Obsolete Functions
+**************************************/
+/*
+Obsolete decompression functions
+These function names are deprecated and should no longer be used.
+They are only provided here for compatibility with older user programs.
+- LZ4_uncompress is the same as LZ4_decompress_fast +- LZ4_uncompress_unknownOutputSize is the same as LZ4_decompress_safe +These function prototypes are now disabled; uncomment them if you really need them. +It is highly recommended to stop using these functions and migrate to newer ones */ +/* int LZ4_uncompress (const char* source, char* dest, int outputSize); */ +/* int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize); */ + + +/* Obsolete streaming functions; use new streaming interface whenever possible */ +void* LZ4_create (const char* inputBuffer); +int LZ4_sizeofStreamState(void); +int LZ4_resetStreamState(void* state, const char* inputBuffer); +char* LZ4_slideInputBuffer (void* state); + +/* Obsolete streaming decoding functions */ +int LZ4_decompress_safe_withPrefix64k (const char* source, char* dest, int compressedSize, int maxOutputSize); +int LZ4_decompress_fast_withPrefix64k (const char* source, char* dest, int originalSize); + + +#if defined (__cplusplus) +} +#endif diff --git a/vendor/github.com/cloudflare/golz4/src/lz4hc.c b/vendor/github.com/cloudflare/golz4/src/lz4hc.c new file mode 100644 index 00000000..5549969c --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/src/lz4hc.c @@ -0,0 +1,751 @@ +/* +LZ4 HC - High Compression Mode of LZ4 +Copyright (C) 2011-2014, Yann Collet. +BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+You can contact the author at :
+- LZ4 homepage : http://fastcompression.blogspot.com/p/lz4.html
+- LZ4 source repository : http://code.google.com/p/lz4/
+*/
+
+
+
+/**************************************
+   Tuning Parameter
+**************************************/
+static const int LZ4HC_compressionLevel_default = 8;
+
+
+/**************************************
+   Includes
+**************************************/
+#include "lz4hc.h"
+
+
+/**************************************
+   Local Compiler Options
+**************************************/
+#if defined(__GNUC__)
+#  pragma GCC diagnostic ignored "-Wunused-function"
+#endif
+
+#if defined (__clang__)
+#  pragma clang diagnostic ignored "-Wunused-function"
+#endif
+
+
+/**************************************
+   Common LZ4 definition
+**************************************/
+#define LZ4_COMMONDEFS_ONLY
+#include "lz4.c"
+
+
+/**************************************
+   Local Constants
+**************************************/
+#define DICTIONARY_LOGSIZE 16
+#define MAXD (1<<DICTIONARY_LOGSIZE)
+#define MAXD_MASK (MAXD - 1)
+
+#define HASH_LOG (DICTIONARY_LOGSIZE-1)
+#define HASHTABLESIZE (1 << HASH_LOG)
+#define HASH_MASK (HASHTABLESIZE - 1)
+
+#define OPTIMAL_ML (int)((ML_MASK-1)+MINMATCH)
+
+static const int g_maxCompressionLevel = 16;
+
+
+/**************************************
+   Local Types
+**************************************/
+typedef struct
+{
+    U32   hashTable[HASHTABLESIZE];
+    U16   chainTable[MAXD];
+    const BYTE* end;         /* next block here to continue on current prefix */
+    const BYTE* base;        /* All index relative to this position */
+    const BYTE* dictBase;    /* alternate base for extDict */
+    const BYTE* inputBuffer; /* deprecated */
+    U32   dictLimit;         /* below that point, need extDict */
+    U32   lowLimit;          /* below that point, no more dict */
+    U32   nextToUpdate;      /* index from which to continue dictionary update */
+    unsigned compressionLevel;
+} LZ4HC_Data_Structure;
+
+
+/**************************************
+   Local Macros
+**************************************/
+#define HASH_FUNCTION(i)   (((i) * 2654435761U) >> ((MINMATCH*8)-HASH_LOG))
+#define DELTANEXT(p)       chainTable[(size_t)(p) & MAXD_MASK]
+#define GETNEXT(p)         ((p) - (size_t)DELTANEXT(p))
+
+static U32 LZ4HC_hashPtr(const void* ptr) { return HASH_FUNCTION(LZ4_read32(ptr)); }
+
+
+
+/**************************************
+   HC Compression
+**************************************/
+static void LZ4HC_init (LZ4HC_Data_Structure* hc4, const BYTE* start)
+{
+    MEM_INIT((void*)hc4->hashTable, 0, sizeof(hc4->hashTable));
+    MEM_INIT(hc4->chainTable, 0xFF, sizeof(hc4->chainTable));
+    hc4->nextToUpdate = 64 KB;
+    hc4->base = start - 64 KB;
+    hc4->inputBuffer = start;
+    hc4->end = start;
+    hc4->dictBase = start - 64 KB;
+    hc4->dictLimit = 64 KB;
+    hc4->lowLimit = 64 KB;
+}
+
+
+/* Update chains up to ip (excluded) */
+FORCE_INLINE void LZ4HC_Insert (LZ4HC_Data_Structure* hc4, const BYTE* ip)
+{
+    U16* chainTable = hc4->chainTable;
+    U32* HashTable  = hc4->hashTable;
+    const BYTE* const base = hc4->base;
+    const U32 target = (U32)(ip - base);
+    U32 idx = hc4->nextToUpdate;
+
+    while(idx < target)
+    {
+        U32 h = LZ4HC_hashPtr(base+idx);
+        size_t delta = idx - HashTable[h];
+        if (delta>MAX_DISTANCE) delta = MAX_DISTANCE;
+        chainTable[idx & 0xFFFF] = (U16)delta;
+        HashTable[h] = idx;
+        idx++;
+    }
+
+    hc4->nextToUpdate = target;
+}
+
+
+FORCE_INLINE int LZ4HC_InsertAndFindBestMatch (LZ4HC_Data_Structure* hc4,   /* Index table will be updated */
+                                               const BYTE* ip, const BYTE* const iLimit,
+                                               const BYTE** matchpos,
+                                               const int maxNbAttempts)
+{
+    U16* const chainTable = hc4->chainTable;
+    U32* const HashTable = hc4->hashTable;
+    const BYTE* const base = hc4->base;
+    const BYTE* const dictBase = hc4->dictBase;
+    const U32 dictLimit = hc4->dictLimit;
+    const U32 lowLimit = (hc4->lowLimit + 64 KB > (U32)(ip-base)) ?
hc4->lowLimit : (U32)(ip - base) - (64 KB - 1); + U32 matchIndex; + const BYTE* match; + int nbAttempts=maxNbAttempts; + size_t ml=0; + + /* HC4 match finder */ + LZ4HC_Insert(hc4, ip); + matchIndex = HashTable[LZ4HC_hashPtr(ip)]; + + while ((matchIndex>=lowLimit) && (nbAttempts)) + { + nbAttempts--; + if (matchIndex >= dictLimit) + { + match = base + matchIndex; + if (*(match+ml) == *(ip+ml) + && (LZ4_read32(match) == LZ4_read32(ip))) + { + size_t mlt = LZ4_count(ip+MINMATCH, match+MINMATCH, iLimit) + MINMATCH; + if (mlt > ml) { ml = mlt; *matchpos = match; } + } + } + else + { + match = dictBase + matchIndex; + if (LZ4_read32(match) == LZ4_read32(ip)) + { + size_t mlt; + const BYTE* vLimit = ip + (dictLimit - matchIndex); + if (vLimit > iLimit) vLimit = iLimit; + mlt = LZ4_count(ip+MINMATCH, match+MINMATCH, vLimit) + MINMATCH; + if ((ip+mlt == vLimit) && (vLimit < iLimit)) + mlt += LZ4_count(ip+mlt, base+dictLimit, iLimit); + if (mlt > ml) { ml = mlt; *matchpos = base + matchIndex; } /* virtual matchpos */ + } + } + matchIndex -= chainTable[matchIndex & 0xFFFF]; + } + + return (int)ml; +} + + +FORCE_INLINE int LZ4HC_InsertAndGetWiderMatch ( + LZ4HC_Data_Structure* hc4, + const BYTE* ip, + const BYTE* iLowLimit, + const BYTE* iHighLimit, + int longest, + const BYTE** matchpos, + const BYTE** startpos, + const int maxNbAttempts) +{ + U16* const chainTable = hc4->chainTable; + U32* const HashTable = hc4->hashTable; + const BYTE* const base = hc4->base; + const U32 dictLimit = hc4->dictLimit; + const U32 lowLimit = (hc4->lowLimit + 64 KB > (U32)(ip-base)) ? hc4->lowLimit : (U32)(ip - base) - (64 KB - 1); + const BYTE* const dictBase = hc4->dictBase; + const BYTE* match; + U32 matchIndex; + int nbAttempts = maxNbAttempts; + int delta = (int)(ip-iLowLimit); + + + /* First Match */ + LZ4HC_Insert(hc4, ip); + matchIndex = HashTable[LZ4HC_hashPtr(ip)]; + + while ((matchIndex>=lowLimit) && (nbAttempts)) + { + nbAttempts--; + if (matchIndex >= dictLimit) + { + match = base + matchIndex; + if (*(iLowLimit + longest) == *(match - delta + longest)) + if (LZ4_read32(match) == LZ4_read32(ip)) + { + const BYTE* startt = ip; + const BYTE* tmpMatch = match; + const BYTE* const matchEnd = ip + MINMATCH + LZ4_count(ip+MINMATCH, match+MINMATCH, iHighLimit); + + while ((startt>iLowLimit) && (tmpMatch > iLowLimit) && (startt[-1] == tmpMatch[-1])) {startt--; tmpMatch--;} + + if ((matchEnd-startt) > longest) + { + longest = (int)(matchEnd-startt); + *matchpos = tmpMatch; + *startpos = startt; + } + } + } + else + { + match = dictBase + matchIndex; + if (LZ4_read32(match) == LZ4_read32(ip)) + { + size_t mlt; + int back=0; + const BYTE* vLimit = ip + (dictLimit - matchIndex); + if (vLimit > iHighLimit) vLimit = iHighLimit; + mlt = LZ4_count(ip+MINMATCH, match+MINMATCH, vLimit) + MINMATCH; + if ((ip+mlt == vLimit) && (vLimit < iHighLimit)) + mlt += LZ4_count(ip+mlt, base+dictLimit, iHighLimit); + while ((ip+back > iLowLimit) && (matchIndex+back > lowLimit) && (ip[back-1] == match[back-1])) back--; + mlt -= back; + if ((int)mlt > longest) { longest = (int)mlt; *matchpos = base + matchIndex + back; *startpos = ip+back; } + } + } + matchIndex -= chainTable[matchIndex & 0xFFFF]; + } + + return longest; +} + + +typedef enum { noLimit = 0, limitedOutput = 1 } limitedOutput_directive; + +#define LZ4HC_DEBUG 0 +#if LZ4HC_DEBUG +static unsigned debug = 0; +#endif + +FORCE_INLINE int LZ4HC_encodeSequence ( + const BYTE** ip, + BYTE** op, + const BYTE** anchor, + int matchLength, + const BYTE* const match, + 
limitedOutput_directive limitedOutputBuffer,
+    BYTE* oend)
+{
+    int length;
+    BYTE* token;
+
+#if LZ4HC_DEBUG
+    if (debug) printf("literal : %u -- match : %u -- offset : %u\n", (U32)(*ip - *anchor), (U32)matchLength, (U32)(*ip-match));
+#endif
+
+    /* Encode Literal length */
+    length = (int)(*ip - *anchor);
+    token = (*op)++;
+    if ((limitedOutputBuffer) && ((*op + (length>>8) + length + (2 + 1 + LASTLITERALS)) > oend)) return 1;   /* Check output limit */
+    if (length>=(int)RUN_MASK) { int len; *token=(RUN_MASK<<ML_BITS); len = length-RUN_MASK; for(; len > 254 ; len-=255) *(*op)++ = 255; *(*op)++ = (BYTE)len; }
+    else *token = (BYTE)(length<<ML_BITS);
+
+    /* Copy Literals */
+    LZ4_wildCopy(*op, *anchor, (*op) + length);
+    *op += length;
+
+    /* Encode Offset */
+    LZ4_writeLE16(*op, (U16)(*ip-match)); *op += 2;
+
+    /* Encode MatchLength */
+    length = (int)(matchLength-MINMATCH);
+    if ((limitedOutputBuffer) && (*op + (length>>8) + (1 + LASTLITERALS) > oend)) return 1;   /* Check output limit */
+    if (length>=(int)ML_MASK) { *token+=ML_MASK; length-=ML_MASK; for(; length > 509 ; length-=510) { *(*op)++ = 255; *(*op)++ = 255; } if (length > 254) { length-=255; *(*op)++ = 255; } *(*op)++ = (BYTE)length; }
+    else *token += (BYTE)(length);
+
+    /* Prepare next loop */
+    *ip += matchLength;
+    *anchor = *ip;
+
+    return 0;
+}
+
+
+static int LZ4HC_compress_generic (
+                 void* ctxvoid,
+                 const char* source,
+                 char* dest,
+                 int inputSize,
+                 int maxOutputSize,
+                 int compressionLevel,
+                 limitedOutput_directive limit
+                )
+{
+    LZ4HC_Data_Structure* ctx = (LZ4HC_Data_Structure*) ctxvoid;
+    const BYTE* ip = (const BYTE*) source;
+    const BYTE* anchor = ip;
+    const BYTE* const iend = ip + inputSize;
+    const BYTE* const mflimit = iend - MFLIMIT;
+    const BYTE* const matchlimit = (iend - LASTLITERALS);
+
+    BYTE* op = (BYTE*) dest;
+    BYTE* const oend = op + maxOutputSize;
+
+    unsigned maxNbAttempts;
+    int ml, ml2, ml3, ml0;
+    const BYTE* ref=NULL;
+    const BYTE* start2=NULL;
+    const BYTE* ref2=NULL;
+    const BYTE* start3=NULL;
+    const BYTE* ref3=NULL;
+    const BYTE* start0;
+    const BYTE* ref0;
+
+
+    /* init */
+    if (compressionLevel > g_maxCompressionLevel) compressionLevel = g_maxCompressionLevel;
+    if (compressionLevel < 1) compressionLevel = LZ4HC_compressionLevel_default;
+    maxNbAttempts = 1 << (compressionLevel-1);
+    ctx->end += inputSize;
+
+    ip++;
+
+    /* Main Loop */
+    while (ip < mflimit)
+    {
+        ml = LZ4HC_InsertAndFindBestMatch (ctx, ip, matchlimit, (&ref), maxNbAttempts);
+        if (!ml) { ip++; continue; }
+
+        /* saved, in case we would skip too much */
+        start0 = ip;
+        ref0 = ref;
+        ml0 = ml;
+
+_Search2:
+        if (ip+ml < mflimit)
+            ml2 = LZ4HC_InsertAndGetWiderMatch(ctx, ip + ml - 2, ip + 1, matchlimit, ml, &ref2, &start2, maxNbAttempts);
+        else ml2 = ml;
+
+        if (ml2 == ml)   /* No better match */
+        {
+            if (LZ4HC_encodeSequence(&ip, &op, &anchor, ml, ref, limit, oend)) return 0;
+            continue;
+        }
+
+        if (start0 < ip)
+        {
+            if (start2 < ip + ml0)   /* empirical */
+            {
+                ip = start0;
+                ref = ref0;
+                ml = ml0;
+            }
+        }
+
+        /* Here, start0==ip */
+        if ((start2 - ip) < 3)   /* First Match too small : removed */
+        {
+            ml = ml2;
+            ip = start2;
+            ref =ref2;
+            goto _Search2;
+        }
+
+_Search3:
+        /*
+         * Currently we have :
+         * ml2 > ml1, and
+         * ip1+3 <= ip2 (usually < ip1+ml1)
+         */
+        if ((start2 - ip) < OPTIMAL_ML)
+        {
+            int correction;
+            int new_ml = ml;
+            if (new_ml > OPTIMAL_ML) new_ml = OPTIMAL_ML;
+            if (ip+new_ml > start2 + ml2 - MINMATCH) new_ml = (int)(start2 - ip) + ml2 - MINMATCH;
+            correction = new_ml - (int)(start2 - ip);
+            if (correction > 0)
+            {
+                start2 += correction;
+                ref2 += correction;
+                ml2 -= correction;
+            }
+        }
+        /* Now, we have start2 = ip+new_ml, with new_ml = min(ml, OPTIMAL_ML=18) */
+
+        if (start2 + ml2 < mflimit)
+            ml3 = LZ4HC_InsertAndGetWiderMatch(ctx, start2 + ml2 - 3, start2, matchlimit, ml2, &ref3, &start3, maxNbAttempts);
+        else ml3 = ml2;
+
+        if (ml3 == ml2)   /* No better match : 2 sequences to encode */
+        {
+            /* ip & ref are known; Now for ml */
+            if (start2 < ip+ml) ml = (int)(start2 - ip);
+            /* Now, encode 2 sequences */
+            if (LZ4HC_encodeSequence(&ip, &op, &anchor, ml, ref, limit, oend)) return 0;
+            ip = start2;
+            if (LZ4HC_encodeSequence(&ip, &op, &anchor, ml2, ref2, limit, oend)) return 0;
+            continue;
+        }
+
+        if (start3 < ip+ml+3)   /* Not enough space for match 2 : remove it */
+        {
+            if (start3 >= (ip+ml))   /* can write Seq1 immediately ==> Seq2 is removed, so Seq3 becomes Seq1 */
+            {
+                if (start2 < ip+ml)
+                {
+                    int correction = (int)(ip+ml - start2);
+                    start2 += correction;
+                    ref2 += correction;
+                    ml2 -= correction;
+                    if (ml2 < MINMATCH)
+                    {
+                        start2 = start3;
+                        ref2 = ref3;
+                        ml2 = ml3;
+                    }
+                }
+
+                if (LZ4HC_encodeSequence(&ip, &op, &anchor, ml, ref, limit, oend)) return 0;
+                ip = start3;
+                ref = ref3;
+                ml = ml3;
+
+                start0 = start2;
+                ref0 = ref2;
+                ml0 = ml2;
+                goto _Search2;
+            }
+
+            start2 = start3;
+            ref2 = ref3;
+            ml2 = ml3;
+            goto _Search3;
+        }
+
+        /*
+         * OK, now we have 3 ascending matches; let's write at least the first one
+         * ip & ref are known; Now for ml
+         */
+        if (start2 < ip+ml)
+        {
+            if ((start2 - ip) < (int)ML_MASK)
+            {
+                int correction;
+                if (ml > OPTIMAL_ML) ml = OPTIMAL_ML;
+                if (ip + ml > start2 + ml2 - MINMATCH) ml = (int)(start2 - ip) + ml2 - MINMATCH;
+                correction = ml - (int)(start2 - ip);
+                if (correction > 0)
+                {
+                    start2 += correction;
+                    ref2 += correction;
+                    ml2 -= correction;
+                }
+            }
+            else
+            {
+                ml = (int)(start2 - ip);
+            }
+        }
+        if (LZ4HC_encodeSequence(&ip, &op, &anchor, ml, ref, limit, oend)) return 0;
+
+        ip = start2;
+        ref = ref2;
+        ml = ml2;
+
+        start2 = start3;
+        ref2 = ref3;
+        ml2 = ml3;
+
+        goto _Search3;
+    }
+
+    /* Encode Last Literals */
+    {
+        int lastRun = (int)(iend - anchor);
+        if ((limit) && (((char*)op - dest) + lastRun + 1 + ((lastRun+255-RUN_MASK)/255) > (U32)maxOutputSize)) return 0;   /* Check output limit */
+        if (lastRun>=(int)RUN_MASK) { *op++=(RUN_MASK<<ML_BITS); lastRun-=RUN_MASK; for(; lastRun > 254 ; lastRun-=255) *op++ = 255; *op++ = (BYTE) lastRun; }
+        else *op++ = (BYTE)(lastRun<<ML_BITS);
+        memcpy(op, anchor, iend - anchor);
+        op += iend-anchor;
+    }
+
+    /* End */
+    return (int) (((char*)op)-dest);
+}
+
+
+int LZ4_compressHC2 (const char* source, char* dest, int inputSize, int compressionLevel)
+{
+    LZ4HC_Data_Structure ctx;
+    LZ4HC_init(&ctx, (const BYTE*)source);
+    return LZ4HC_compress_generic (&ctx, source, dest, inputSize, 0, compressionLevel, noLimit);
+}
+
+int LZ4_compressHC (const char* source, char* dest, int inputSize) { return LZ4_compressHC2(source, dest, inputSize, 0); }
+
+int LZ4_compressHC2_limitedOutput (const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel)
+{
+    LZ4HC_Data_Structure ctx;
+    LZ4HC_init(&ctx, (const BYTE*)source);
+    return LZ4HC_compress_generic (&ctx, source, dest, inputSize, maxOutputSize, compressionLevel, limitedOutput);
+}
+
+int LZ4_compressHC_limitedOutput (const char* source, char* dest, int inputSize, int maxOutputSize)
+{
+    return LZ4_compressHC2_limitedOutput(source, dest, inputSize, maxOutputSize, 0);
+}
+
+
+/*****************************
+   Using external allocation
+*****************************/
+int LZ4_sizeofStateHC(void) { return sizeof(LZ4HC_Data_Structure); }
+
+int LZ4_compressHC2_withStateHC (void* state, const char* source, char* dest, int inputSize, int compressionLevel)
+{
+    if (((size_t)(state)&(sizeof(void*)-1)) != 0) return 0;   /* Error : state is not aligned for pointers (32 or 64 bits) */
+    LZ4HC_init ((LZ4HC_Data_Structure*)state, (const BYTE*)source);
+    return LZ4HC_compress_generic (state, source, dest, inputSize, 0, compressionLevel, noLimit);
+}
+
+int LZ4_compressHC_withStateHC (void* state, const char* source, char* dest, int inputSize)
+{ return LZ4_compressHC2_withStateHC (state, source, dest, inputSize, 0); }
+
+int LZ4_compressHC2_limitedOutput_withStateHC (void* state, const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel)
+{
+    if (((size_t)(state)&(sizeof(void*)-1)) != 0) return 0;   /* Error : state is not aligned for pointers (32 or 64 bits) */
+    LZ4HC_init ((LZ4HC_Data_Structure*)state, (const BYTE*)source);
+    return LZ4HC_compress_generic (state, source, dest, inputSize, maxOutputSize, compressionLevel, limitedOutput);
+}
+
+int LZ4_compressHC_limitedOutput_withStateHC (void* state, const char* source, char* dest, int inputSize, int maxOutputSize)
+{ return LZ4_compressHC2_limitedOutput_withStateHC (state, source, dest, inputSize, maxOutputSize, 0); }
+
+
+/**************************************
+   Streaming Functions
+**************************************/
+/* allocation */
+LZ4_streamHC_t* LZ4_createStreamHC(void) { return (LZ4_streamHC_t*)ALLOCATOR(1, sizeof(LZ4_streamHC_t)); }
+int             LZ4_freeStreamHC (LZ4_streamHC_t* LZ4_streamHCPtr) { FREEMEM(LZ4_streamHCPtr); return 0; }
+
+/* initialization */
+void LZ4_resetStreamHC (LZ4_streamHC_t* LZ4_streamHCPtr, int compressionLevel)
+{
+    ((LZ4HC_Data_Structure*)LZ4_streamHCPtr)->base = NULL;
+    ((LZ4HC_Data_Structure*)LZ4_streamHCPtr)->compressionLevel = (unsigned)compressionLevel;
+}
+
+int LZ4_loadDictHC (LZ4_streamHC_t* LZ4_streamHCPtr, const char* dictionary, int dictSize)
+{
+    LZ4HC_Data_Structure* ctxPtr = (LZ4HC_Data_Structure*) LZ4_streamHCPtr;
+    if (dictSize > 64 KB)
+    {
+        dictionary += dictSize - 64 KB;
+        dictSize = 64 KB;
+    }
+    LZ4HC_init (ctxPtr, (const BYTE*)dictionary);
+    if (dictSize >= 4) LZ4HC_Insert (ctxPtr, (const BYTE*)dictionary +(dictSize-3));
+    ctxPtr->end = (const BYTE*)dictionary + dictSize;
+    return dictSize;
+}
+
+
+/* compression */
+
+static void LZ4HC_setExternalDict(LZ4HC_Data_Structure* ctxPtr, const BYTE* newBlock)
+{
+    if (ctxPtr->end >= ctxPtr->base + 4)
+        LZ4HC_Insert (ctxPtr, ctxPtr->end-3);   /* Referencing remaining dictionary content */
+    /* Only one memory segment for extDict, so any previous extDict is lost at this stage */
+    ctxPtr->lowLimit  = ctxPtr->dictLimit;
+    ctxPtr->dictLimit = (U32)(ctxPtr->end - ctxPtr->base);
+    ctxPtr->dictBase  = ctxPtr->base;
+    ctxPtr->base = newBlock - ctxPtr->dictLimit;
+    ctxPtr->end  = newBlock;
+    ctxPtr->nextToUpdate = ctxPtr->dictLimit;   /* match referencing will resume from there */
+}
+
+static int LZ4_compressHC_continue_generic (LZ4HC_Data_Structure* ctxPtr,
+                                            const char* source, char* dest,
+                                            int inputSize, int maxOutputSize, limitedOutput_directive limit)
+{
+    /* auto-init if forgotten */
+    if (ctxPtr->base == NULL)
LZ4HC_init (ctxPtr, (const BYTE*) source); + + /* Check overflow */ + if ((size_t)(ctxPtr->end - ctxPtr->base) > 2 GB) + { + size_t dictSize = (size_t)(ctxPtr->end - ctxPtr->base) - ctxPtr->dictLimit; + if (dictSize > 64 KB) dictSize = 64 KB; + + LZ4_loadDictHC((LZ4_streamHC_t*)ctxPtr, (const char*)(ctxPtr->end) - dictSize, (int)dictSize); + } + + /* Check if blocks follow each other */ + if ((const BYTE*)source != ctxPtr->end) LZ4HC_setExternalDict(ctxPtr, (const BYTE*)source); + + /* Check overlapping input/dictionary space */ + { + const BYTE* sourceEnd = (const BYTE*) source + inputSize; + const BYTE* dictBegin = ctxPtr->dictBase + ctxPtr->lowLimit; + const BYTE* dictEnd = ctxPtr->dictBase + ctxPtr->dictLimit; + if ((sourceEnd > dictBegin) && ((BYTE*)source < dictEnd)) + { + if (sourceEnd > dictEnd) sourceEnd = dictEnd; + ctxPtr->lowLimit = (U32)(sourceEnd - ctxPtr->dictBase); + if (ctxPtr->dictLimit - ctxPtr->lowLimit < 4) ctxPtr->lowLimit = ctxPtr->dictLimit; + } + } + + return LZ4HC_compress_generic (ctxPtr, source, dest, inputSize, maxOutputSize, ctxPtr->compressionLevel, limit); +} + +int LZ4_compressHC_continue (LZ4_streamHC_t* LZ4_streamHCPtr, const char* source, char* dest, int inputSize) +{ + return LZ4_compressHC_continue_generic ((LZ4HC_Data_Structure*)LZ4_streamHCPtr, source, dest, inputSize, 0, noLimit); +} + +int LZ4_compressHC_limitedOutput_continue (LZ4_streamHC_t* LZ4_streamHCPtr, const char* source, char* dest, int inputSize, int maxOutputSize) +{ + return LZ4_compressHC_continue_generic ((LZ4HC_Data_Structure*)LZ4_streamHCPtr, source, dest, inputSize, maxOutputSize, limitedOutput); +} + + +/* dictionary saving */ + +int LZ4_saveDictHC (LZ4_streamHC_t* LZ4_streamHCPtr, char* safeBuffer, int dictSize) +{ + LZ4HC_Data_Structure* streamPtr = (LZ4HC_Data_Structure*)LZ4_streamHCPtr; + int prefixSize = (int)(streamPtr->end - (streamPtr->base + streamPtr->dictLimit)); + if (dictSize > 64 KB) dictSize = 64 KB; + if (dictSize < 4) dictSize = 0; + if (dictSize > prefixSize) dictSize = prefixSize; + memcpy(safeBuffer, streamPtr->end - dictSize, dictSize); + { + U32 endIndex = (U32)(streamPtr->end - streamPtr->base); + streamPtr->end = (const BYTE*)safeBuffer + dictSize; + streamPtr->base = streamPtr->end - endIndex; + streamPtr->dictLimit = endIndex - dictSize; + streamPtr->lowLimit = endIndex - dictSize; + if (streamPtr->nextToUpdate < streamPtr->dictLimit) streamPtr->nextToUpdate = streamPtr->dictLimit; + } + return dictSize; +} + + +/*********************************** + * Deprecated Functions + ***********************************/ +int LZ4_sizeofStreamStateHC(void) { return LZ4_STREAMHCSIZE; } + +int LZ4_resetStreamStateHC(void* state, const char* inputBuffer) +{ + if ((((size_t)state) & (sizeof(void*)-1)) != 0) return 1; /* Error : pointer is not aligned for pointer (32 or 64 bits) */ + LZ4HC_init((LZ4HC_Data_Structure*)state, (const BYTE*)inputBuffer); + return 0; +} + +void* LZ4_createHC (const char* inputBuffer) +{ + void* hc4 = ALLOCATOR(1, sizeof(LZ4HC_Data_Structure)); + LZ4HC_init ((LZ4HC_Data_Structure*)hc4, (const BYTE*)inputBuffer); + return hc4; +} + +int LZ4_freeHC (void* LZ4HC_Data) +{ + FREEMEM(LZ4HC_Data); + return (0); +} + +/* +int LZ4_compressHC_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize) +{ +return LZ4HC_compress_generic (LZ4HC_Data, source, dest, inputSize, 0, 0, noLimit); +} +int LZ4_compressHC_limitedOutput_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize, int maxOutputSize) +{ +return 
LZ4HC_compress_generic (LZ4HC_Data, source, dest, inputSize, maxOutputSize, 0, limitedOutput); +} +*/ + +int LZ4_compressHC2_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize, int compressionLevel) +{ + return LZ4HC_compress_generic (LZ4HC_Data, source, dest, inputSize, 0, compressionLevel, noLimit); +} + +int LZ4_compressHC2_limitedOutput_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel) +{ + return LZ4HC_compress_generic (LZ4HC_Data, source, dest, inputSize, maxOutputSize, compressionLevel, limitedOutput); +} + +char* LZ4_slideInputBufferHC(void* LZ4HC_Data) +{ + LZ4HC_Data_Structure* hc4 = (LZ4HC_Data_Structure*)LZ4HC_Data; + int dictSize = LZ4_saveDictHC((LZ4_streamHC_t*)LZ4HC_Data, (char*)(hc4->inputBuffer), 64 KB); + return (char*)(hc4->inputBuffer + dictSize); +} diff --git a/vendor/github.com/cloudflare/golz4/src/lz4hc.h b/vendor/github.com/cloudflare/golz4/src/lz4hc.h new file mode 100644 index 00000000..ce813aba --- /dev/null +++ b/vendor/github.com/cloudflare/golz4/src/lz4hc.h @@ -0,0 +1,174 @@ +/* + LZ4 HC - High Compression Mode of LZ4 + Header File + Copyright (C) 2011-2014, Yann Collet. + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 homepage : http://fastcompression.blogspot.com/p/lz4.html + - LZ4 source repository : http://code.google.com/p/lz4/ +*/ +#pragma once + + +#if defined (__cplusplus) +extern "C" { +#endif + + +int LZ4_compressHC (const char* source, char* dest, int inputSize); +/* +LZ4_compressHC : + return : the number of bytes in compressed buffer dest + or 0 if compression fails. + note : destination buffer must be already allocated. + To avoid any problem, size it to handle worst cases situations (input data not compressible) + Worst case size evaluation is provided by function LZ4_compressBound() (see "lz4.h") +*/ + +int LZ4_compressHC_limitedOutput (const char* source, char* dest, int inputSize, int maxOutputSize); +/* +LZ4_compress_limitedOutput() : + Compress 'inputSize' bytes from 'source' into an output buffer 'dest' of maximum size 'maxOutputSize'. + If it cannot achieve it, compression will stop, and result of the function will be zero. 
+    This function never writes outside of provided output buffer.
+
+    inputSize  : Max supported value is 1 GB
+    maxOutputSize : is the maximum allowed size of the destination buffer (which must be already allocated)
+    return : the number of output bytes written in buffer 'dest'
+             or 0 if compression fails.
+*/
+
+
+int LZ4_compressHC2 (const char* source, char* dest, int inputSize, int compressionLevel);
+int LZ4_compressHC2_limitedOutput (const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel);
+/*
+    Same functions as above, but with programmable 'compressionLevel'.
+    Recommended values are between 4 and 9, although any value between 0 and 16 will work.
+    'compressionLevel'==0 means use default 'compressionLevel' value.
+    Values above 16 behave the same as 16.
+    Equivalent variants exist for all other compression functions below.
+*/
+
+/* Note :
+   Decompression functions are provided within LZ4 source code (see "lz4.h") (BSD license)
+*/
+
+
+/**************************************
+   Using an external allocation
+**************************************/
+int LZ4_sizeofStateHC(void);
+int LZ4_compressHC_withStateHC               (void* state, const char* source, char* dest, int inputSize);
+int LZ4_compressHC_limitedOutput_withStateHC (void* state, const char* source, char* dest, int inputSize, int maxOutputSize);
+
+int LZ4_compressHC2_withStateHC              (void* state, const char* source, char* dest, int inputSize, int compressionLevel);
+int LZ4_compressHC2_limitedOutput_withStateHC(void* state, const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel);
+
+/*
+These functions are provided should you prefer to allocate memory for compression tables with your own allocation methods.
+To know how much memory must be allocated for the compression tables, use :
+int LZ4_sizeofStateHC();
+
+Note that tables must be aligned for pointer (32 or 64 bits), otherwise compression will fail (return code 0).
+
+The allocated memory can be provided to the compression functions using 'void* state' parameter.
+LZ4_compressHC_withStateHC() and LZ4_compressHC_limitedOutput_withStateHC() are equivalent to the previously described functions.
+They just use the externally allocated memory for state instead of allocating their own (on stack, or on heap).
+*/
+
+
+
+/**************************************
+   Experimental Streaming Functions
+**************************************/
+#define LZ4_STREAMHCSIZE_U64 32774
+#define LZ4_STREAMHCSIZE     (LZ4_STREAMHCSIZE_U64 * sizeof(unsigned long long))
+typedef struct { unsigned long long table[LZ4_STREAMHCSIZE_U64]; } LZ4_streamHC_t;
+/*
+LZ4_streamHC_t
+This structure allows static allocation of LZ4 HC streaming state.
+State must then be initialized using LZ4_resetStreamHC() before first use.
+
+Static allocation should only be used with statically linked library.
+If you want to use LZ4 as a DLL, please use construction functions below, which are more future-proof.
+*/
+
+
+LZ4_streamHC_t* LZ4_createStreamHC(void);
+int             LZ4_freeStreamHC (LZ4_streamHC_t* LZ4_streamHCPtr);
+/*
+These functions create and release memory for LZ4 HC streaming state.
+Newly created states are already initialized.
+Existing state space can be re-used anytime using LZ4_resetStreamHC().
+If you use LZ4 as a DLL, please use these functions instead of direct struct allocation,
+to avoid size mismatch between different versions.
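+
+A minimal allocation sketch (illustrative only : src, dst, srcSize and level
+are placeholders) :
+
+    LZ4_streamHC_t* s = LZ4_createStreamHC();
+    LZ4_resetStreamHC(s, level);
+    int csize = LZ4_compressHC_continue(s, src, dst, srcSize);
+    LZ4_freeStreamHC(s);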
+*/
+
+void LZ4_resetStreamHC (LZ4_streamHC_t* LZ4_streamHCPtr, int compressionLevel);
+int  LZ4_loadDictHC (LZ4_streamHC_t* LZ4_streamHCPtr, const char* dictionary, int dictSize);
+
+int LZ4_compressHC_continue (LZ4_streamHC_t* LZ4_streamHCPtr, const char* source, char* dest, int inputSize);
+int LZ4_compressHC_limitedOutput_continue (LZ4_streamHC_t* LZ4_streamHCPtr, const char* source, char* dest, int inputSize, int maxOutputSize);
+
+int LZ4_saveDictHC (LZ4_streamHC_t* LZ4_streamHCPtr, char* safeBuffer, int maxDictSize);
+
+/*
+These functions compress data in successive blocks of any size, using previous blocks as dictionary.
+One key assumption is that each previous block will remain read-accessible while compressing the next block.
+
+Before starting compression, state must be properly initialized, using LZ4_resetStreamHC().
+A first "fictional block" can then be designated as initial dictionary, using LZ4_loadDictHC() (Optional).
+
+Then, use LZ4_compressHC_continue() or LZ4_compressHC_limitedOutput_continue() to compress each successive block.
+They work like usual LZ4_compressHC() or LZ4_compressHC_limitedOutput(), but use previous memory blocks to improve compression.
+Previous memory blocks (including initial dictionary when present) must remain accessible and unmodified during compression.
+
+If, for any reason, the previous data block can't be preserved in memory during the next compression block,
+you must save it to a safer memory space,
+using LZ4_saveDictHC().
+*/
+
+
+
+/**************************************
+ * Deprecated Streaming Functions
+ * ************************************/
+/* Note : these streaming functions follow the older model, and should no longer be used */
+void* LZ4_createHC (const char* inputBuffer);
+char* LZ4_slideInputBufferHC (void* LZ4HC_Data);
+int   LZ4_freeHC (void* LZ4HC_Data);
+
+int LZ4_compressHC2_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize, int compressionLevel);
+int LZ4_compressHC2_limitedOutput_continue (void* LZ4HC_Data, const char* source, char* dest, int inputSize, int maxOutputSize, int compressionLevel);
+
+int LZ4_sizeofStreamStateHC(void);
+int LZ4_resetStreamStateHC(void* state, const char* inputBuffer);
+
+
+#if defined (__cplusplus)
+}
+#endif
diff --git a/vendor/github.com/couchbase/go-couchbase/.gitignore b/vendor/github.com/couchbase/go-couchbase/.gitignore
new file mode 100644
index 00000000..eda885ce
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/.gitignore
@@ -0,0 +1,14 @@
+#*
+*.6
+*.a
+*~
+*.swp
+/examples/basic/basic
+/hello/hello
+/populate/populate
+/tools/view2go/view2go
+/tools/loadfile/loadfile
+gotags.files
+TAGS
+6.out
+_*
diff --git a/vendor/github.com/couchbase/go-couchbase/.travis.yml b/vendor/github.com/couchbase/go-couchbase/.travis.yml
new file mode 100644
index 00000000..4ecafb18
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/.travis.yml
@@ -0,0 +1,5 @@
+language: go
+install: go get -v -d ./... && go build -v ./...
+script: go test -v ./...
+
+go: 1.1.1
diff --git a/vendor/github.com/couchbase/go-couchbase/LICENSE b/vendor/github.com/couchbase/go-couchbase/LICENSE
new file mode 100644
index 00000000..0b23ef35
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2013 Couchbase, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vendor/github.com/couchbase/go-couchbase/README.markdown b/vendor/github.com/couchbase/go-couchbase/README.markdown
new file mode 100644
index 00000000..bf5fe494
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/README.markdown
@@ -0,0 +1,37 @@
+# A smart client for couchbase in go
+
+This is an *unofficial* version of a Couchbase Golang client. If you are
+looking for the *official* Couchbase Golang client please see
+[CB-go](https://github.com/couchbaselabs/gocb).
+
+This is an evolving package, but does provide a useful interface to a
+[couchbase](http://www.couchbase.com/) server including all of the
+pool/bucket discovery features, compatible key distribution with other
+clients, and vbucket motion awareness so applications can continue to
+operate during rebalances.
+
+It also supports view querying with source node randomization so you
+don't hammer a single node with all the work.
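+
+For instance, once a bucket is open (see the example below), a view can be
+queried like this (a sketch only; the design document and view names here
+are made up):
+
+    vres, err := bucket.View("beer", "brewery_beers", map[string]interface{}{"limit": 10})
+    if err != nil {
+        log.Fatalf("Error querying view: %v", err)
+    }
+    for _, row := range vres.Rows {
+        log.Printf("row id: %v", row.ID)
+    }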
+
+## Install
+
+    go get github.com/couchbase/go-couchbase
+
+## Example
+
+    c, err := couchbase.Connect("http://dev-couchbase.example.com:8091/")
+    if err != nil {
+        log.Fatalf("Error connecting: %v", err)
+    }
+
+    pool, err := c.GetPool("default")
+    if err != nil {
+        log.Fatalf("Error getting pool: %v", err)
+    }
+
+    bucket, err := pool.GetBucket("default")
+    if err != nil {
+        log.Fatalf("Error getting bucket: %v", err)
+    }
+
+    bucket.Set("someKey", 0, []string{"an", "example", "list"})
diff --git a/vendor/github.com/couchbase/go-couchbase/audit.go b/vendor/github.com/couchbase/go-couchbase/audit.go
new file mode 100644
index 00000000..3db7d9f9
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/audit.go
@@ -0,0 +1,32 @@
+package couchbase
+
+import ()
+
+// Sample data:
+// {"disabled":["12333", "22244"],"uid":"132492431","auditdEnabled":true,
+// "disabledUsers":[{"name":"bill","domain":"local"},{"name":"bob","domain":"local"}],
+// "logPath":"/Users/johanlarson/Library/Application Support/Couchbase/var/lib/couchbase/logs",
+// "rotateInterval":86400,"rotateSize":20971520}
+type AuditSpec struct {
+	Disabled       []uint32    `json:"disabled"`
+	Uid            string      `json:"uid"`
+	AuditdEnabled  bool        `json:"auditdEnabled"`
+	DisabledUsers  []AuditUser `json:"disabledUsers"`
+	LogPath        string      `json:"logPath"`
+	RotateInterval int64       `json:"rotateInterval"`
+	RotateSize     int64       `json:"rotateSize"`
+}
+
+type AuditUser struct {
+	Name   string `json:"name"`
+	Domain string `json:"domain"`
+}
+
+func (c *Client) GetAuditSpec() (*AuditSpec, error) {
+	ret := &AuditSpec{}
+	err := c.parseURLResponse("/settings/audit", ret)
+	if err != nil {
+		return nil, err
+	}
+	return ret, nil
+}
diff --git a/vendor/github.com/couchbase/go-couchbase/client.go b/vendor/github.com/couchbase/go-couchbase/client.go
new file mode 100644
index 00000000..00a4a7de
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/client.go
@@ -0,0 +1,1436 @@
+/*
+Package couchbase provides a smart client for go.
+
+Usage:
+
+ client, err := couchbase.Connect("http://myserver:8091/")
+ handleError(err)
+ pool, err := client.GetPool("default")
+ handleError(err)
+ bucket, err := pool.GetBucket("MyAwesomeBucket")
+ handleError(err)
+ ...
+
+or a shortcut for the bucket directly
+
+ bucket, err := couchbase.GetBucket("http://myserver:8091/", "default", "default")
+
+in any case, you can specify authentication credentials using
+standard URL userinfo syntax:
+
+ b, err := couchbase.GetBucket("http://bucketname:bucketpass@myserver:8091/",
+         "default", "bucket")
+*/
+package couchbase
+
+import (
+	"encoding/binary"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+	"unsafe"
+
+	"github.com/couchbase/gomemcached"
+	"github.com/couchbase/gomemcached/client" // package name is 'memcached'
+	"github.com/couchbase/goutils/logging"
+)
+
+// Mutation Token
+type MutationToken struct {
+	VBid  uint16 // vbucket id
+	Guard uint64 // vbuuid
+	Value uint64 // sequence number
+}
+
+// Maximum number of times to retry a chunk of a bulk get on error.
+var MaxBulkRetries = 5000
+var backOffDuration time.Duration = 100 * time.Millisecond
+var MaxBackOffRetries = 25 // exponential backOff results in over 30sec (25*13*0.1s)
+
+// If this is set to a nonzero duration, Do() and ViewCustom() will log a warning if the call
+// takes longer than that.
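+// For example (the threshold value below is illustrative only):
+//
+//     couchbase.SlowServerCallWarningThreshold = 250 * time.Millisecond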
+var SlowServerCallWarningThreshold time.Duration
+
+func slowLog(startTime time.Time, format string, args ...interface{}) {
+	if elapsed := time.Now().Sub(startTime); elapsed > SlowServerCallWarningThreshold {
+		pc, _, _, _ := runtime.Caller(2)
+		caller := runtime.FuncForPC(pc).Name()
+		logging.Infof("go-couchbase: "+format+" in "+caller+" took "+elapsed.String(), args...)
+	}
+}
+
+// Return true if error is KEY_EEXISTS. Required by cbq-engine
+func IsKeyEExistsError(err error) bool {
+
+	res, ok := err.(*gomemcached.MCResponse)
+	if ok && res.Status == gomemcached.KEY_EEXISTS {
+		return true
+	}
+
+	return false
+}
+
+// Return true if error is KEY_ENOENT. Required by cbq-engine
+func IsKeyNoEntError(err error) bool {
+
+	res, ok := err.(*gomemcached.MCResponse)
+	if ok && res.Status == gomemcached.KEY_ENOENT {
+		return true
+	}
+
+	return false
+}
+
+// Return true if error suggests a bucket refresh is required. Required by cbq-engine
+func IsRefreshRequired(err error) bool {
+
+	res, ok := err.(*gomemcached.MCResponse)
+	if ok && (res.Status == gomemcached.NO_BUCKET || res.Status == gomemcached.NOT_MY_VBUCKET) {
+		return true
+	}
+
+	return false
+}
+
+// ClientOpCallback is called for each invocation of Do.
+var ClientOpCallback func(opname, k string, start time.Time, err error)
+
+// Do executes a function on a memcached connection to the node owning key "k"
+//
+// Note that this automatically handles transient errors by replaying
+// your function on a "not-my-vbucket" error, so don't assume
+// your command will be executed only once.
+func (b *Bucket) Do(k string, f func(mc *memcached.Client, vb uint16) error) (err error) {
+	if SlowServerCallWarningThreshold > 0 {
+		defer slowLog(time.Now(), "call to Do(%q)", k)
+	}
+
+	vb := b.VBHash(k)
+
+	backOffAttempts := 0
+	maxTries := len(b.Nodes()) * 2
+	for i := 0; i < maxTries; i++ {
+		conn, pool, err := b.getConnectionToVBucket(vb)
+		if err != nil {
+			return err
+		}
+
+		if DefaultTimeout > 0 {
+			conn.SetDeadline(getDeadline(noDeadline, DefaultTimeout))
+			err = f(conn, uint16(vb))
+			conn.SetDeadline(noDeadline)
+		} else {
+			err = f(conn, uint16(vb))
+		}
+
+		var st gomemcached.Status
+		var retry bool
+
+		// MB-30967 / MB-31001 implement back off for transient errors
+		if i, ok := err.(*gomemcached.MCResponse); ok {
+			st = i.Status
+			switch st {
+			case gomemcached.NOT_MY_VBUCKET:
+				b.Refresh()
+				retry = true
+			case gomemcached.NOT_SUPPORTED:
+				retry = true
+			case gomemcached.ENOMEM:
+				fallthrough
+			case gomemcached.TMPFAIL:
+				backOffAttempts++
+				retry = backOff(backOffAttempts, MaxBackOffRetries, backOffDuration, true)
+			default:
+				retry = false
+			}
+		}
+
+		// MB-28842: in case of NMVB, check if the node is still part of the map
+		// and ditch the connection if it isn't.
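+		// (Discard drops the now-suspect connection entirely, while Return
+		// hands a still-usable connection back to the pool for reuse.)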
+		if st == gomemcached.NOT_MY_VBUCKET && b.checkVBmap(pool.Node()) {
+			pool.Discard(conn)
+		} else {
+			pool.Return(conn)
+		}
+
+		if !retry {
+			return err
+		}
+	}
+
+	return fmt.Errorf("unable to complete action after %v attempts",
+		maxTries)
+}
+
+// DoNoDeadline is the same as Do, i.e. it executes a function on a memcached
+// connection to the node owning key "k", but it avoids setting the connection
+// deadline, leaving that responsibility to the function itself.
+// All other considerations of Do() apply.
+func (b *Bucket) DoNoDeadline(k string, f func(mc *memcached.Client, vb uint16) error) (err error) {
+	if SlowServerCallWarningThreshold > 0 {
+		defer slowLog(time.Now(), "call to DoNoDeadline(%q)", k)
+	}
+
+	vb := b.VBHash(k)
+
+	backOffAttempts := 0
+	maxTries := len(b.Nodes()) * 2
+	for i := 0; i < maxTries; i++ {
+		conn, pool, err := b.getConnectionToVBucket(vb)
+		if err != nil {
+			return err
+		}
+
+		err = f(conn, uint16(vb))
+
+		var st gomemcached.Status
+		var retry bool
+
+		// MB-30967 / MB-31001 implement back off for transient errors
+		if i, ok := err.(*gomemcached.MCResponse); ok {
+			st = i.Status
+			switch st {
+			case gomemcached.NOT_MY_VBUCKET:
+				b.Refresh()
+				retry = true
+			case gomemcached.NOT_SUPPORTED:
+				retry = true
+			case gomemcached.ENOMEM:
+				fallthrough
+			case gomemcached.TMPFAIL:
+				backOffAttempts++
+				retry = backOff(backOffAttempts, MaxBackOffRetries, backOffDuration, true)
+			default:
+				retry = false
+			}
+		}
+
+		// MB-28842: in case of NMVB, check if the node is still part of the map
+		// and ditch the connection if it isn't.
+		if st == gomemcached.NOT_MY_VBUCKET && b.checkVBmap(pool.Node()) {
+			pool.Discard(conn)
+		} else {
+			pool.Return(conn)
+		}
+
+		if !retry {
+			return err
+		}
+	}
+
+	return fmt.Errorf("unable to complete action after %v attempts",
+		maxTries)
+}
+
+type GatheredStats struct {
+	Server string
+	Stats  map[string]string
+	Err    error
+}
+
+func getStatsParallel(sn string, b *Bucket, offset int, which string,
+	ch chan<- GatheredStats) {
+	pool := b.getConnPool(offset)
+	var gatheredStats GatheredStats
+
+	conn, err := pool.Get()
+	defer func() {
+		pool.Return(conn)
+		ch <- gatheredStats
+	}()
+
+	if err != nil {
+		gatheredStats = GatheredStats{Server: sn, Err: err}
+	} else {
+		sm, err := conn.StatsMap(which)
+		gatheredStats = GatheredStats{Server: sn, Stats: sm, Err: err}
+	}
+}
+
+// GetStats gets a set of stats from all servers.
+//
+// Returns a map of server ID -> map of stat key to stat value.
+func (b *Bucket) GetStats(which string) map[string]map[string]string {
+	rv := map[string]map[string]string{}
+	for server, gs := range b.GatherStats(which) {
+		if len(gs.Stats) > 0 {
+			rv[server] = gs.Stats
+		}
+	}
+	return rv
+}
+
+// GatherStats returns a map of server ID -> GatheredStats from all servers.
+func (b *Bucket) GatherStats(which string) map[string]GatheredStats {
+	vsm := b.VBServerMap()
+	if vsm.ServerList == nil {
+		return nil
+	}
+
+	// Go grab all the things at once.
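+	// One goroutine per server; each sends its GatheredStats on the
+	// buffered channel, so results can be collected in any order.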
+ ch := make(chan GatheredStats, len(vsm.ServerList)) + for i, sn := range vsm.ServerList { + go getStatsParallel(sn, b, i, which, ch) + } + + // Gather the results + rv := map[string]GatheredStats{} + for range vsm.ServerList { + gs := <-ch + rv[gs.Server] = gs + } + return rv +} + +// Get bucket count through the bucket stats +func (b *Bucket) GetCount(refresh bool) (count int64, err error) { + if refresh { + b.Refresh() + } + + var cnt int64 + for _, gs := range b.GatherStats("") { + if len(gs.Stats) > 0 { + cnt, err = strconv.ParseInt(gs.Stats["curr_items"], 10, 64) + if err != nil { + return 0, err + } + count += cnt + } + } + + return count, nil +} + +func isAuthError(err error) bool { + estr := err.Error() + return strings.Contains(estr, "Auth failure") +} + +func IsReadTimeOutError(err error) bool { + estr := err.Error() + return strings.Contains(estr, "read tcp") || + strings.Contains(estr, "i/o timeout") +} + +func isTimeoutError(err error) bool { + estr := err.Error() + return strings.Contains(estr, "i/o timeout") || + strings.Contains(estr, "connection timed out") || + strings.Contains(estr, "no route to host") +} + +// Errors that are not considered fatal for our fetch loop +func isConnError(err error) bool { + if err == io.EOF { + return true + } + estr := err.Error() + return strings.Contains(estr, "broken pipe") || + strings.Contains(estr, "connection reset") || + strings.Contains(estr, "connection refused") || + strings.Contains(estr, "connection pool is closed") +} + +func isOutOfBoundsError(err error) bool { + return strings.Contains(err.Error(), "Out of Bounds error") + +} + +func getDeadline(reqDeadline time.Time, duration time.Duration) time.Time { + if reqDeadline.IsZero() && duration > 0 { + return time.Now().Add(duration) + } + return reqDeadline +} + +func backOff(attempt, maxAttempts int, duration time.Duration, exponential bool) bool { + if attempt < maxAttempts { + // 0th attempt return immediately + if attempt > 0 { + if exponential { + duration = time.Duration(attempt) * duration + } + time.Sleep(duration) + } + return true + } + + return false +} + +func (b *Bucket) doBulkGet(vb uint16, keys []string, reqDeadline time.Time, + ch chan<- map[string]*gomemcached.MCResponse, ech chan<- error, subPaths []string, + eStatus *errorStatus) { + if SlowServerCallWarningThreshold > 0 { + defer slowLog(time.Now(), "call to doBulkGet(%d, %d keys)", vb, len(keys)) + } + + rv := _STRING_MCRESPONSE_POOL.Get() + attempts := 0 + backOffAttempts := 0 + done := false + bname := b.Name + for ; attempts < MaxBulkRetries && !done && !eStatus.errStatus; attempts++ { + + if len(b.VBServerMap().VBucketMap) < int(vb) { + //fatal + err := fmt.Errorf("vbmap smaller than requested for %v", bname) + logging.Errorf("go-couchbase: %v vb %d vbmap len %d", err.Error(), vb, len(b.VBServerMap().VBucketMap)) + ech <- err + return + } + + masterID := b.VBServerMap().VBucketMap[vb][0] + + if masterID < 0 { + // fatal + err := fmt.Errorf("No master node available for %v vb %d", bname, vb) + logging.Errorf("%v", err.Error()) + ech <- err + return + } + + // This stack frame exists to ensure we can clean up + // connection at a reasonable time. 
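+		// The deferred Return/Discard inside this closure runs at the end
+		// of each attempt, not at the end of the whole retry loop.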
+ err := func() error { + pool := b.getConnPool(masterID) + conn, err := pool.Get() + if err != nil { + if isAuthError(err) || isTimeoutError(err) { + logging.Errorf("Fatal Error %v : %v", bname, err) + ech <- err + return err + } else if isConnError(err) { + if !backOff(backOffAttempts, MaxBackOffRetries, backOffDuration, true) { + logging.Errorf("Connection Error %v : %v", bname, err) + ech <- err + return err + } + b.Refresh() + backOffAttempts++ + } + logging.Infof("Pool Get returned %v: %v", bname, err) + // retry + return nil + } + + conn.SetDeadline(getDeadline(reqDeadline, DefaultTimeout)) + err = conn.GetBulk(vb, keys, rv, subPaths) + conn.SetDeadline(noDeadline) + + discard := false + defer func() { + if discard { + pool.Discard(conn) + } else { + pool.Return(conn) + } + }() + + switch err.(type) { + case *gomemcached.MCResponse: + notSMaxTries := len(b.Nodes()) * 2 + st := err.(*gomemcached.MCResponse).Status + if st == gomemcached.NOT_MY_VBUCKET || (st == gomemcached.NOT_SUPPORTED && attempts < notSMaxTries) { + b.Refresh() + discard = b.checkVBmap(pool.Node()) + return nil // retry + } else if st == gomemcached.EBUSY || st == gomemcached.LOCKED { + if (attempts % (MaxBulkRetries / 100)) == 0 { + logging.Infof("Retrying Memcached error (%v) FOR %v(vbid:%d, keys:%v)", + err.Error(), bname, vb, keys) + } + return nil // retry + } else if (st == gomemcached.ENOMEM || st == gomemcached.TMPFAIL) && backOff(backOffAttempts, MaxBackOffRetries, backOffDuration, true) { + // MB-30967 / MB-31001 use backoff for TMPFAIL too + backOffAttempts++ + logging.Infof("Retrying Memcached error (%v) FOR %v(vbid:%d, keys:%v)", + err.Error(), bname, vb, keys) + return nil // retry + } + ech <- err + return err + case error: + if isOutOfBoundsError(err) { + // We got an out of bound error, retry the operation + return nil + } else if isConnError(err) && backOff(backOffAttempts, MaxBackOffRetries, backOffDuration, true) { + backOffAttempts++ + logging.Errorf("Connection Error: %s. 
Refreshing bucket %v (vbid:%v,keys:%v)", + err.Error(), bname, vb, keys) + discard = true + b.Refresh() + return nil // retry + } + ech <- err + ch <- rv + return err + } + + done = true + return nil + }() + + if err != nil { + return + } + } + + if attempts >= MaxBulkRetries { + err := fmt.Errorf("bulkget exceeded MaxBulkRetries for %v(vbid:%d,keys:%v)", bname, vb, keys) + logging.Errorf("%v", err.Error()) + ech <- err + } + + ch <- rv +} + +type errorStatus struct { + errStatus bool +} + +type vbBulkGet struct { + b *Bucket + ch chan<- map[string]*gomemcached.MCResponse + ech chan<- error + k uint16 + keys []string + reqDeadline time.Time + wg *sync.WaitGroup + subPaths []string + groupError *errorStatus +} + +const _NUM_CHANNELS = 5 + +var _NUM_CHANNEL_WORKERS = (runtime.NumCPU() + 1) / 2 +var DefaultDialTimeout = time.Duration(0) +var DefaultTimeout = time.Duration(0) +var noDeadline = time.Time{} + +// Buffer 4k requests per worker +var _VB_BULK_GET_CHANNELS []chan *vbBulkGet + +func InitBulkGet() { + + DefaultDialTimeout = 20 * time.Second + DefaultTimeout = 120 * time.Second + + memcached.SetDefaultDialTimeout(DefaultDialTimeout) + + _VB_BULK_GET_CHANNELS = make([]chan *vbBulkGet, _NUM_CHANNELS) + + for i := 0; i < _NUM_CHANNELS; i++ { + channel := make(chan *vbBulkGet, 16*1024*_NUM_CHANNEL_WORKERS) + _VB_BULK_GET_CHANNELS[i] = channel + + for j := 0; j < _NUM_CHANNEL_WORKERS; j++ { + go vbBulkGetWorker(channel) + } + } +} + +func vbBulkGetWorker(ch chan *vbBulkGet) { + defer func() { + // Workers cannot panic and die + recover() + go vbBulkGetWorker(ch) + }() + + for vbg := range ch { + vbDoBulkGet(vbg) + } +} + +func vbDoBulkGet(vbg *vbBulkGet) { + defer vbg.wg.Done() + defer func() { + // Workers cannot panic and die + recover() + }() + vbg.b.doBulkGet(vbg.k, vbg.keys, vbg.reqDeadline, vbg.ch, vbg.ech, vbg.subPaths, vbg.groupError) +} + +var _ERR_CHAN_FULL = fmt.Errorf("Data request queue full, aborting query.") + +func (b *Bucket) processBulkGet(kdm map[uint16][]string, reqDeadline time.Time, + ch chan<- map[string]*gomemcached.MCResponse, ech chan<- error, subPaths []string, + eStatus *errorStatus) { + + defer close(ch) + defer close(ech) + + wg := &sync.WaitGroup{} + + for k, keys := range kdm { + + // GetBulk() group has error donot Queue items for this group + if eStatus.errStatus { + break + } + + vbg := &vbBulkGet{ + b: b, + ch: ch, + ech: ech, + k: k, + keys: keys, + reqDeadline: reqDeadline, + wg: wg, + subPaths: subPaths, + groupError: eStatus, + } + + wg.Add(1) + + // Random int + // Right shift to avoid 8-byte alignment, and take low bits + c := (uintptr(unsafe.Pointer(vbg)) >> 4) % _NUM_CHANNELS + + select { + case _VB_BULK_GET_CHANNELS[c] <- vbg: + // No-op + default: + // Buffer full, abandon the bulk get + ech <- _ERR_CHAN_FULL + wg.Add(-1) + } + } + + // Wait for my vb bulk gets + wg.Wait() +} + +type multiError []error + +func (m multiError) Error() string { + if len(m) == 0 { + panic("Error of none") + } + + return fmt.Sprintf("{%v errors, starting with %v}", len(m), m[0].Error()) +} + +// Convert a stream of errors from ech into a multiError (or nil) and +// send down eout. +// +// At least one send is guaranteed on eout, but two is possible, so +// buffer the out channel appropriately. 
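+//
+// (getBulk below sizes eout with a capacity of 2 for exactly this reason.)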
+func errorCollector(ech <-chan error, eout chan<- error, eStatus *errorStatus) {
+	defer func() { eout <- nil }()
+	var errs multiError
+	for e := range ech {
+		if !eStatus.errStatus && !IsKeyNoEntError(e) {
+			eStatus.errStatus = true
+		}
+
+		errs = append(errs, e)
+	}
+
+	if len(errs) > 0 {
+		eout <- errs
+	}
+}
+
+// GetBulkRaw fetches multiple keys concurrently, with []byte values.
+//
+// This is a wrapper around GetBulk which converts all values returned
+// by GetBulk from raw memcached responses into []byte slices.
+// Returns one document for duplicate keys.
+func (b *Bucket) GetBulkRaw(keys []string) (map[string][]byte, error) {
+
+	resp, eout := b.getBulk(keys, noDeadline, nil)
+
+	rv := make(map[string][]byte, len(keys))
+	for k, av := range resp {
+		rv[k] = av.Body
+	}
+
+	b.ReleaseGetBulkPools(resp)
+	return rv, eout
+
+}
+
+// GetBulk fetches multiple keys concurrently.
+//
+// Unlike more convenient GETs, the entire response is returned in the
+// map for each key. Keys that were not found will not be included in
+// the map.
+func (b *Bucket) GetBulk(keys []string, reqDeadline time.Time, subPaths []string) (map[string]*gomemcached.MCResponse, error) {
+	return b.getBulk(keys, reqDeadline, subPaths)
+}
+
+func (b *Bucket) ReleaseGetBulkPools(rv map[string]*gomemcached.MCResponse) {
+	_STRING_MCRESPONSE_POOL.Put(rv)
+}
+
+func (b *Bucket) getBulk(keys []string, reqDeadline time.Time, subPaths []string) (map[string]*gomemcached.MCResponse, error) {
+	kdm := _VB_STRING_POOL.Get()
+	defer _VB_STRING_POOL.Put(kdm)
+	for _, k := range keys {
+		if k != "" {
+			vb := uint16(b.VBHash(k))
+			a, ok1 := kdm[vb]
+			if !ok1 {
+				a = _STRING_POOL.Get()
+			}
+			kdm[vb] = append(a, k)
+		}
+	}
+
+	eout := make(chan error, 2)
+	groupErrorStatus := &errorStatus{}
+
+	// processBulkGet will own both of these channels and
+	// guarantee they're closed before it returns.
+	ch := make(chan map[string]*gomemcached.MCResponse)
+	ech := make(chan error)
+
+	go errorCollector(ech, eout, groupErrorStatus)
+	go b.processBulkGet(kdm, reqDeadline, ch, ech, subPaths, groupErrorStatus)
+
+	var rv map[string]*gomemcached.MCResponse
+
+	for m := range ch {
+		if rv == nil {
+			rv = m
+			continue
+		}
+
+		for k, v := range m {
+			rv[k] = v
+		}
+		_STRING_MCRESPONSE_POOL.Put(m)
+	}
+
+	return rv, <-eout
+}
+
+// WriteOptions is the set of option flags available for the Write
+// method. They are ORed together to specify the desired request.
+type WriteOptions int
+
+const (
+	// Raw specifies that the value is raw []byte or nil; don't
+	// JSON-encode it.
+	Raw = WriteOptions(1 << iota)
+	// AddOnly indicates an item should only be written if it
+	// doesn't exist, otherwise ErrKeyExists is returned.
+	AddOnly
+	// Persist causes the operation to block until the server
+	// confirms the item is persisted.
+	Persist
+	// Indexable causes the operation to block until it's available via the index.
+	Indexable
+	// Append indicates the given value should be appended to the
+	// existing value for the given key.
+ Append +) + +var optNames = []struct { + opt WriteOptions + name string +}{ + {Raw, "raw"}, + {AddOnly, "addonly"}, {Persist, "persist"}, + {Indexable, "indexable"}, {Append, "append"}, +} + +// String representation of WriteOptions +func (w WriteOptions) String() string { + f := []string{} + for _, on := range optNames { + if w&on.opt != 0 { + f = append(f, on.name) + w &= ^on.opt + } + } + if len(f) == 0 || w != 0 { + f = append(f, fmt.Sprintf("0x%x", int(w))) + } + return strings.Join(f, "|") +} + +// Error returned from Write with AddOnly flag, when key already exists in the bucket. +var ErrKeyExists = errors.New("key exists") + +// General-purpose value setter. +// +// The Set, Add and Delete methods are just wrappers around this. The +// interpretation of `v` depends on whether the `Raw` option is +// given. If it is, v must be a byte array or nil. (A nil value causes +// a delete.) If `Raw` is not given, `v` will be marshaled as JSON +// before being written. It must be JSON-marshalable and it must not +// be nil. +func (b *Bucket) Write(k string, flags, exp int, v interface{}, + opt WriteOptions) (err error) { + + if ClientOpCallback != nil { + defer func(t time.Time) { + ClientOpCallback(fmt.Sprintf("Write(%v)", opt), k, t, err) + }(time.Now()) + } + + var data []byte + if opt&Raw == 0 { + data, err = json.Marshal(v) + if err != nil { + return err + } + } else if v != nil { + data = v.([]byte) + } + + var res *gomemcached.MCResponse + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + if opt&AddOnly != 0 { + res, err = memcached.UnwrapMemcachedError( + mc.Add(vb, k, flags, exp, data)) + if err == nil && res.Status != gomemcached.SUCCESS { + if res.Status == gomemcached.KEY_EEXISTS { + err = ErrKeyExists + } else { + err = res + } + } + } else if opt&Append != 0 { + res, err = mc.Append(vb, k, data) + } else if data == nil { + res, err = mc.Del(vb, k) + } else { + res, err = mc.Set(vb, k, flags, exp, data) + } + + return err + }) + + if err == nil && (opt&(Persist|Indexable) != 0) { + err = b.WaitForPersistence(k, res.Cas, data == nil) + } + + return err +} + +func (b *Bucket) WriteWithMT(k string, flags, exp int, v interface{}, + opt WriteOptions) (mt *MutationToken, err error) { + + if ClientOpCallback != nil { + defer func(t time.Time) { + ClientOpCallback(fmt.Sprintf("WriteWithMT(%v)", opt), k, t, err) + }(time.Now()) + } + + var data []byte + if opt&Raw == 0 { + data, err = json.Marshal(v) + if err != nil { + return nil, err + } + } else if v != nil { + data = v.([]byte) + } + + var res *gomemcached.MCResponse + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + if opt&AddOnly != 0 { + res, err = memcached.UnwrapMemcachedError( + mc.Add(vb, k, flags, exp, data)) + if err == nil && res.Status != gomemcached.SUCCESS { + if res.Status == gomemcached.KEY_EEXISTS { + err = ErrKeyExists + } else { + err = res + } + } + } else if opt&Append != 0 { + res, err = mc.Append(vb, k, data) + } else if data == nil { + res, err = mc.Del(vb, k) + } else { + res, err = mc.Set(vb, k, flags, exp, data) + } + + if len(res.Extras) >= 16 { + vbuuid := uint64(binary.BigEndian.Uint64(res.Extras[0:8])) + seqNo := uint64(binary.BigEndian.Uint64(res.Extras[8:16])) + mt = &MutationToken{VBid: vb, Guard: vbuuid, Value: seqNo} + } + + return err + }) + + if err == nil && (opt&(Persist|Indexable) != 0) { + err = b.WaitForPersistence(k, res.Cas, data == nil) + } + + return mt, err +} + +// Set a value in this bucket with Cas and return the new Cas value +func (b *Bucket) Cas(k string, 
exp int, cas uint64, v interface{}) (uint64, error) {
+	return b.WriteCas(k, 0, exp, cas, v, 0)
+}
+
+// Set a value in this bucket with Cas without json encoding it
+func (b *Bucket) CasRaw(k string, exp int, cas uint64, v interface{}) (uint64, error) {
+	return b.WriteCas(k, 0, exp, cas, v, Raw)
+}
+
+func (b *Bucket) WriteCas(k string, flags, exp int, cas uint64, v interface{},
+	opt WriteOptions) (newCas uint64, err error) {
+
+	if ClientOpCallback != nil {
+		defer func(t time.Time) {
+			ClientOpCallback(fmt.Sprintf("Write(%v)", opt), k, t, err)
+		}(time.Now())
+	}
+
+	var data []byte
+	if opt&Raw == 0 {
+		data, err = json.Marshal(v)
+		if err != nil {
+			return 0, err
+		}
+	} else if v != nil {
+		data = v.([]byte)
+	}
+
+	var res *gomemcached.MCResponse
+	err = b.Do(k, func(mc *memcached.Client, vb uint16) error {
+		res, err = mc.SetCas(vb, k, flags, exp, cas, data)
+		return err
+	})
+
+	// guard against a nil response when Do fails before SetCas runs
+	if err != nil {
+		return 0, err
+	}
+
+	if opt&(Persist|Indexable) != 0 {
+		err = b.WaitForPersistence(k, res.Cas, data == nil)
+	}
+
+	return res.Cas, err
+}
+
+// Extended CAS operation. These functions will return the mutation token, i.e. vbuuid & guard
+func (b *Bucket) CasWithMeta(k string, flags int, exp int, cas uint64, v interface{}) (uint64, *MutationToken, error) {
+	return b.WriteCasWithMT(k, flags, exp, cas, v, 0)
+}
+
+func (b *Bucket) CasWithMetaRaw(k string, flags int, exp int, cas uint64, v interface{}) (uint64, *MutationToken, error) {
+	return b.WriteCasWithMT(k, flags, exp, cas, v, Raw)
+}
+
+func (b *Bucket) WriteCasWithMT(k string, flags, exp int, cas uint64, v interface{},
+	opt WriteOptions) (newCas uint64, mt *MutationToken, err error) {
+
+	if ClientOpCallback != nil {
+		defer func(t time.Time) {
+			ClientOpCallback(fmt.Sprintf("Write(%v)", opt), k, t, err)
+		}(time.Now())
+	}
+
+	var data []byte
+	if opt&Raw == 0 {
+		data, err = json.Marshal(v)
+		if err != nil {
+			return 0, nil, err
+		}
+	} else if v != nil {
+		data = v.([]byte)
+	}
+
+	var res *gomemcached.MCResponse
+	err = b.Do(k, func(mc *memcached.Client, vb uint16) error {
+		res, err = mc.SetCas(vb, k, flags, exp, cas, data)
+		return err
+	})
+
+	if err != nil {
+		return 0, nil, err
+	}
+
+	// check for extras
+	if len(res.Extras) >= 16 {
+		vbuuid := uint64(binary.BigEndian.Uint64(res.Extras[0:8]))
+		seqNo := uint64(binary.BigEndian.Uint64(res.Extras[8:16]))
+		vb := b.VBHash(k)
+		mt = &MutationToken{VBid: uint16(vb), Guard: vbuuid, Value: seqNo}
+	}
+
+	if opt&(Persist|Indexable) != 0 {
+		err = b.WaitForPersistence(k, res.Cas, data == nil)
+	}
+
+	return res.Cas, mt, err
+}
+
+// Set a value in this bucket.
+// The value will be serialized into a JSON document.
+func (b *Bucket) Set(k string, exp int, v interface{}) error {
+	return b.Write(k, 0, exp, v, 0)
+}
+
+// Set a value in this bucket with flags
+func (b *Bucket) SetWithMeta(k string, flags int, exp int, v interface{}) (*MutationToken, error) {
+	return b.WriteWithMT(k, flags, exp, v, 0)
+}
+
+// SetRaw sets a value in this bucket without JSON encoding it.
+func (b *Bucket) SetRaw(k string, exp int, v []byte) error {
+	return b.Write(k, 0, exp, v, Raw)
+}
+
+// Add adds a value to this bucket; like Set except that nothing
+// happens if the key exists. The value will be serialized into a
+// JSON document.
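+//
+// A typical use (key and value are illustrative):
+//
+//	added, err := bucket.Add("user:1", 0, map[string]string{"name": "ann"})
+//	if err == nil && !added {
+//		// the key already existed
+//	}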
+func (b *Bucket) Add(k string, exp int, v interface{}) (added bool, err error) { + err = b.Write(k, 0, exp, v, AddOnly) + if err == ErrKeyExists { + return false, nil + } + return (err == nil), err +} + +// AddRaw adds a value to this bucket; like SetRaw except that nothing +// happens if the key exists. The value will be stored as raw bytes. +func (b *Bucket) AddRaw(k string, exp int, v []byte) (added bool, err error) { + err = b.Write(k, 0, exp, v, AddOnly|Raw) + if err == ErrKeyExists { + return false, nil + } + return (err == nil), err +} + +// Add adds a value to this bucket; like Set except that nothing +// happens if the key exists. The value will be serialized into a +// JSON document. +func (b *Bucket) AddWithMT(k string, exp int, v interface{}) (added bool, mt *MutationToken, err error) { + mt, err = b.WriteWithMT(k, 0, exp, v, AddOnly) + if err == ErrKeyExists { + return false, mt, nil + } + return (err == nil), mt, err +} + +// AddRaw adds a value to this bucket; like SetRaw except that nothing +// happens if the key exists. The value will be stored as raw bytes. +func (b *Bucket) AddRawWithMT(k string, exp int, v []byte) (added bool, mt *MutationToken, err error) { + mt, err = b.WriteWithMT(k, 0, exp, v, AddOnly|Raw) + if err == ErrKeyExists { + return false, mt, nil + } + return (err == nil), mt, err +} + +// Append appends raw data to an existing item. +func (b *Bucket) Append(k string, data []byte) error { + return b.Write(k, 0, 0, data, Append|Raw) +} + +// Get a value straight from Memcached +func (b *Bucket) GetsMC(key string, reqDeadline time.Time) (*gomemcached.MCResponse, error) { + var err error + var response *gomemcached.MCResponse + + if key == "" { + return nil, nil + } + + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("GetsMC", key, t, err) }(time.Now()) + } + + err = b.DoNoDeadline(key, func(mc *memcached.Client, vb uint16) error { + var err1 error + + mc.SetDeadline(getDeadline(reqDeadline, DefaultTimeout)) + response, err1 = mc.Get(vb, key) + mc.SetDeadline(noDeadline) + if err1 != nil { + return err1 + } + return nil + }) + return response, err +} + +// Get a value through the subdoc API +func (b *Bucket) GetsSubDoc(key string, reqDeadline time.Time, subPaths []string) (*gomemcached.MCResponse, error) { + var err error + var response *gomemcached.MCResponse + + if key == "" { + return nil, nil + } + + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("GetsSubDoc", key, t, err) }(time.Now()) + } + + err = b.DoNoDeadline(key, func(mc *memcached.Client, vb uint16) error { + var err1 error + + mc.SetDeadline(getDeadline(reqDeadline, DefaultTimeout)) + response, err1 = mc.GetSubdoc(vb, key, subPaths) + mc.SetDeadline(noDeadline) + if err1 != nil { + return err1 + } + return nil + }) + return response, err +} + +// GetsRaw gets a raw value from this bucket including its CAS +// counter and flags. +func (b *Bucket) GetsRaw(k string) (data []byte, flags int, + cas uint64, err error) { + + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("GetsRaw", k, t, err) }(time.Now()) + } + + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + res, err := mc.Get(vb, k) + if err != nil { + return err + } + cas = res.Cas + if len(res.Extras) >= 4 { + flags = int(binary.BigEndian.Uint32(res.Extras)) + } + data = res.Body + return nil + }) + return +} + +// Gets gets a value from this bucket, including its CAS counter. 
The +// value is expected to be a JSON stream and will be deserialized into +// rv. +func (b *Bucket) Gets(k string, rv interface{}, caso *uint64) error { + data, _, cas, err := b.GetsRaw(k) + if err != nil { + return err + } + if caso != nil { + *caso = cas + } + return json.Unmarshal(data, rv) +} + +// Get a value from this bucket. +// The value is expected to be a JSON stream and will be deserialized +// into rv. +func (b *Bucket) Get(k string, rv interface{}) error { + return b.Gets(k, rv, nil) +} + +// GetRaw gets a raw value from this bucket. No marshaling is performed. +func (b *Bucket) GetRaw(k string) ([]byte, error) { + d, _, _, err := b.GetsRaw(k) + return d, err +} + +// GetAndTouchRaw gets a raw value from this bucket including its CAS +// counter and flags, and updates the expiry on the doc. +func (b *Bucket) GetAndTouchRaw(k string, exp int) (data []byte, + cas uint64, err error) { + + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("GetsRaw", k, t, err) }(time.Now()) + } + + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + res, err := mc.GetAndTouch(vb, k, exp) + if err != nil { + return err + } + cas = res.Cas + data = res.Body + return nil + }) + return data, cas, err +} + +// GetMeta returns the meta values for a key +func (b *Bucket) GetMeta(k string, flags *int, expiry *int, cas *uint64, seqNo *uint64) (err error) { + + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("GetsMeta", k, t, err) }(time.Now()) + } + + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + res, err := mc.GetMeta(vb, k) + if err != nil { + return err + } + + *cas = res.Cas + if len(res.Extras) >= 8 { + *flags = int(binary.BigEndian.Uint32(res.Extras[4:])) + } + + if len(res.Extras) >= 12 { + *expiry = int(binary.BigEndian.Uint32(res.Extras[8:])) + } + + if len(res.Extras) >= 20 { + *seqNo = uint64(binary.BigEndian.Uint64(res.Extras[12:])) + } + + return nil + }) + + return err +} + +// Delete a key from this bucket. +func (b *Bucket) Delete(k string) error { + return b.Write(k, 0, 0, nil, Raw) +} + +// Incr increments the value at a given key by amt and defaults to def if no value present. 
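+//
+// For example, maintaining a counter (key and amounts are illustrative):
+//
+//	n, err := bucket.Incr("stats:visits", 1, 1, 0)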
+func (b *Bucket) Incr(k string, amt, def uint64, exp int) (val uint64, err error) { + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("Incr", k, t, err) }(time.Now()) + } + + var rv uint64 + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + res, err := mc.Incr(vb, k, amt, def, exp) + if err != nil { + return err + } + rv = res + return nil + }) + return rv, err +} + +// Decr decrements the value at a given key by amt and defaults to def if no value present +func (b *Bucket) Decr(k string, amt, def uint64, exp int) (val uint64, err error) { + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("Decr", k, t, err) }(time.Now()) + } + + var rv uint64 + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + res, err := mc.Decr(vb, k, amt, def, exp) + if err != nil { + return err + } + rv = res + return nil + }) + return rv, err +} + +// Wrapper around memcached.CASNext() +func (b *Bucket) casNext(k string, exp int, state *memcached.CASState) bool { + if ClientOpCallback != nil { + defer func(t time.Time) { + ClientOpCallback("casNext", k, t, state.Err) + }(time.Now()) + } + + keepGoing := false + state.Err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + keepGoing = mc.CASNext(vb, k, exp, state) + return state.Err + }) + return keepGoing && state.Err == nil +} + +// An UpdateFunc is a callback function to update a document +type UpdateFunc func(current []byte) (updated []byte, err error) + +// Return this as the error from an UpdateFunc to cancel the Update +// operation. +const UpdateCancel = memcached.CASQuit + +// Update performs a Safe update of a document, avoiding conflicts by +// using CAS. +// +// The callback function will be invoked with the current raw document +// contents (or nil if the document doesn't exist); it should return +// the updated raw contents (or nil to delete.) If it decides not to +// change anything it can return UpdateCancel as the error. +// +// If another writer modifies the document between the get and the +// set, the callback will be invoked again with the newer value. +func (b *Bucket) Update(k string, exp int, callback UpdateFunc) error { + _, err := b.update(k, exp, callback) + return err +} + +// internal version of Update that returns a CAS value +func (b *Bucket) update(k string, exp int, callback UpdateFunc) (newCas uint64, err error) { + var state memcached.CASState + for b.casNext(k, exp, &state) { + var err error + if state.Value, err = callback(state.Value); err != nil { + return 0, err + } + } + return state.Cas, state.Err +} + +// A WriteUpdateFunc is a callback function to update a document +type WriteUpdateFunc func(current []byte) (updated []byte, opt WriteOptions, err error) + +// WriteUpdate performs a Safe update of a document, avoiding +// conflicts by using CAS. WriteUpdate is like Update, except that +// the callback can return a set of WriteOptions, of which Persist and +// Indexable are recognized: these cause the call to wait until the +// document update has been persisted to disk and/or become available +// to index. 
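+//
+// An illustrative callback that appends a byte and requests persistence:
+//
+//	err := bucket.WriteUpdate("k", 0, func(current []byte) ([]byte, couchbase.WriteOptions, error) {
+//		return append(current, '!'), couchbase.Persist, nil
+//	})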
+func (b *Bucket) WriteUpdate(k string, exp int, callback WriteUpdateFunc) error { + var writeOpts WriteOptions + var deletion bool + // Wrap the callback in an UpdateFunc we can pass to Update: + updateCallback := func(current []byte) (updated []byte, err error) { + update, opt, err := callback(current) + writeOpts = opt + deletion = (update == nil) + return update, err + } + cas, err := b.update(k, exp, updateCallback) + if err != nil { + return err + } + // If callback asked, wait for persistence or indexability: + if writeOpts&(Persist|Indexable) != 0 { + err = b.WaitForPersistence(k, cas, deletion) + } + return err +} + +// Observe observes the current state of a document. +func (b *Bucket) Observe(k string) (result memcached.ObserveResult, err error) { + if ClientOpCallback != nil { + defer func(t time.Time) { ClientOpCallback("Observe", k, t, err) }(time.Now()) + } + + err = b.Do(k, func(mc *memcached.Client, vb uint16) error { + result, err = mc.Observe(vb, k) + return err + }) + return +} + +// Returned from WaitForPersistence (or Write, if the Persistent or Indexable flag is used) +// if the value has been overwritten by another before being persisted. +var ErrOverwritten = errors.New("overwritten") + +// Returned from WaitForPersistence (or Write, if the Persistent or Indexable flag is used) +// if the value hasn't been persisted by the timeout interval +var ErrTimeout = errors.New("timeout") + +// WaitForPersistence waits for an item to be considered durable. +// +// Besides transport errors, ErrOverwritten may be returned if the +// item is overwritten before it reaches durability. ErrTimeout may +// occur if the item isn't found durable in a reasonable amount of +// time. +func (b *Bucket) WaitForPersistence(k string, cas uint64, deletion bool) error { + timeout := 10 * time.Second + sleepDelay := 5 * time.Millisecond + start := time.Now() + for { + time.Sleep(sleepDelay) + sleepDelay += sleepDelay / 2 // multiply delay by 1.5 every time + + result, err := b.Observe(k) + if err != nil { + return err + } + if persisted, overwritten := result.CheckPersistence(cas, deletion); overwritten { + return ErrOverwritten + } else if persisted { + return nil + } + + if result.PersistenceTime > 0 { + timeout = 2 * result.PersistenceTime + } + if time.Since(start) >= timeout-sleepDelay { + return ErrTimeout + } + } +} + +var _STRING_MCRESPONSE_POOL = gomemcached.NewStringMCResponsePool(16) + +type stringPool struct { + pool *sync.Pool + size int +} + +func newStringPool(size int) *stringPool { + rv := &stringPool{ + pool: &sync.Pool{ + New: func() interface{} { + return make([]string, 0, size) + }, + }, + size: size, + } + + return rv +} + +func (this *stringPool) Get() []string { + return this.pool.Get().([]string) +} + +func (this *stringPool) Put(s []string) { + if s == nil || cap(s) < this.size || cap(s) > 2*this.size { + return + } + + this.pool.Put(s[0:0]) +} + +var _STRING_POOL = newStringPool(16) + +type vbStringPool struct { + pool *sync.Pool + strPool *stringPool +} + +func newVBStringPool(size int, sp *stringPool) *vbStringPool { + rv := &vbStringPool{ + pool: &sync.Pool{ + New: func() interface{} { + return make(map[uint16][]string, size) + }, + }, + strPool: sp, + } + + return rv +} + +func (this *vbStringPool) Get() map[uint16][]string { + return this.pool.Get().(map[uint16][]string) +} + +func (this *vbStringPool) Put(s map[uint16][]string) { + if s == nil { + return + } + + for k, v := range s { + delete(s, k) + this.strPool.Put(v) + } + + this.pool.Put(s) +} + +var 
_VB_STRING_POOL = newVBStringPool(16, _STRING_POOL)
diff --git a/vendor/github.com/couchbase/go-couchbase/conn_pool.go b/vendor/github.com/couchbase/go-couchbase/conn_pool.go
new file mode 100644
index 00000000..babd3adb
--- /dev/null
+++ b/vendor/github.com/couchbase/go-couchbase/conn_pool.go
@@ -0,0 +1,387 @@
+package couchbase
+
+import (
+	"errors"
+	"sync/atomic"
+	"time"
+
+	"github.com/couchbase/gomemcached"
+	"github.com/couchbase/gomemcached/client"
+	"github.com/couchbase/goutils/logging"
+)
+
+// GenericMcdAuthHandler is a kind of AuthHandler that performs
+// special auth exchange (like non-standard auth, possibly followed by
+// select-bucket).
+type GenericMcdAuthHandler interface {
+	AuthHandler
+	AuthenticateMemcachedConn(host string, conn *memcached.Client) error
+}
+
+// Error raised when a connection can't be retrieved from a pool.
+var TimeoutError = errors.New("timeout waiting to build connection")
+var errClosedPool = errors.New("the connection pool is closed")
+var errNoPool = errors.New("no connection pool")
+
+// Default timeout for retrieving a connection from the pool.
+var ConnPoolTimeout = time.Hour * 24 * 30
+
+// overflow connection closer cycle time
+var ConnCloserInterval = time.Second * 30
+
+// ConnPoolAvailWaitTime is the amount of time to wait for an existing
+// connection from the pool before considering the creation of a new
+// one.
+var ConnPoolAvailWaitTime = time.Millisecond
+
+type connectionPool struct {
+	host        string
+	mkConn      func(host string, ah AuthHandler) (*memcached.Client, error)
+	auth        AuthHandler
+	connections chan *memcached.Client
+	createsem   chan bool
+	bailOut     chan bool
+	poolSize    int
+	connCount   uint64
+	inUse       bool
+}
+
+func newConnectionPool(host string, ah AuthHandler, closer bool, poolSize, poolOverflow int) *connectionPool {
+	connSize := poolSize
+	if closer {
+		connSize += poolOverflow
+	}
+	rv := &connectionPool{
+		host:        host,
+		connections: make(chan *memcached.Client, connSize),
+		createsem:   make(chan bool, poolSize+poolOverflow),
+		mkConn:      defaultMkConn,
+		auth:        ah,
+		poolSize:    poolSize,
+	}
+	if closer {
+		rv.bailOut = make(chan bool, 1)
+		go rv.connCloser()
+	}
+	return rv
+}
+
+// ConnPoolCallback is notified whenever connections are acquired from a pool.
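+//
+// For example, to time pool acquisitions (the logging sink is illustrative):
+//
+//	couchbase.ConnPoolCallback = func(host, source string, start time.Time, err error) {
+//		log.Printf("pool %s via %s took %v (err: %v)", host, source, time.Since(start), err)
+//	}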
+var ConnPoolCallback func(host string, source string, start time.Time, err error) + +func defaultMkConn(host string, ah AuthHandler) (*memcached.Client, error) { + var features memcached.Features + + conn, err := memcached.Connect("tcp", host) + if err != nil { + return nil, err + } + + if TCPKeepalive == true { + conn.SetKeepAliveOptions(time.Duration(TCPKeepaliveInterval) * time.Second) + } + + if EnableMutationToken == true { + features = append(features, memcached.FeatureMutationToken) + } + if EnableDataType == true { + features = append(features, memcached.FeatureDataType) + } + + if EnableXattr == true { + features = append(features, memcached.FeatureXattr) + } + + if len(features) > 0 { + if DefaultTimeout > 0 { + conn.SetDeadline(getDeadline(noDeadline, DefaultTimeout)) + } + + res, err := conn.EnableFeatures(features) + + if DefaultTimeout > 0 { + conn.SetDeadline(noDeadline) + } + + if err != nil && isTimeoutError(err) { + conn.Close() + return nil, err + } + + if err != nil || res.Status != gomemcached.SUCCESS { + logging.Warnf("Unable to enable features %v", err) + } + } + + if gah, ok := ah.(GenericMcdAuthHandler); ok { + err = gah.AuthenticateMemcachedConn(host, conn) + if err != nil { + conn.Close() + return nil, err + } + return conn, nil + } + name, pass, bucket := ah.GetCredentials() + if name != "default" { + _, err = conn.Auth(name, pass) + if err != nil { + conn.Close() + return nil, err + } + // Select bucket (Required for cb_auth creds) + // Required when doing auth with _admin credentials + if bucket != "" && bucket != name { + _, err = conn.SelectBucket(bucket) + if err != nil { + conn.Close() + return nil, err + } + } + } + return conn, nil +} + +func (cp *connectionPool) Close() (err error) { + defer func() { + if recover() != nil { + err = errors.New("connectionPool.Close error") + } + }() + if cp.bailOut != nil { + + // defensively, we won't wait if the channel is full + select { + case cp.bailOut <- false: + default: + } + } + close(cp.connections) + for c := range cp.connections { + c.Close() + } + return +} + +func (cp *connectionPool) Node() string { + return cp.host +} + +func (cp *connectionPool) GetWithTimeout(d time.Duration) (rv *memcached.Client, err error) { + if cp == nil { + return nil, errNoPool + } + + path := "" + + if ConnPoolCallback != nil { + defer func(path *string, start time.Time) { + ConnPoolCallback(cp.host, *path, start, err) + }(&path, time.Now()) + } + + path = "short-circuit" + + // short-circuit available connetions. + select { + case rv, isopen := <-cp.connections: + if !isopen { + return nil, errClosedPool + } + atomic.AddUint64(&cp.connCount, 1) + return rv, nil + default: + } + + t := time.NewTimer(ConnPoolAvailWaitTime) + defer t.Stop() + + // Try to grab an available connection within 1ms + select { + case rv, isopen := <-cp.connections: + path = "avail1" + if !isopen { + return nil, errClosedPool + } + atomic.AddUint64(&cp.connCount, 1) + return rv, nil + case <-t.C: + // No connection came around in time, let's see + // whether we can get one or build a new one first. + t.Reset(d) // Reuse the timer for the full timeout. + select { + case rv, isopen := <-cp.connections: + path = "avail2" + if !isopen { + return nil, errClosedPool + } + atomic.AddUint64(&cp.connCount, 1) + return rv, nil + case cp.createsem <- true: + path = "create" + // Build a connection if we can't get a real one. + // This can potentially be an overflow connection, or + // a pooled connection. 
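+			// The createsem slot acquired above is held for the life of the
+			// connection and only released when the connection is finally
+			// closed (see Return, Discard and connCleanup).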
+ rv, err := cp.mkConn(cp.host, cp.auth) + if err != nil { + // On error, release our create hold + <-cp.createsem + } else { + atomic.AddUint64(&cp.connCount, 1) + } + return rv, err + case <-t.C: + return nil, ErrTimeout + } + } +} + +func (cp *connectionPool) Get() (*memcached.Client, error) { + return cp.GetWithTimeout(ConnPoolTimeout) +} + +func (cp *connectionPool) Return(c *memcached.Client) { + if c == nil { + return + } + + if cp == nil { + c.Close() + } + + if c.IsHealthy() { + defer func() { + if recover() != nil { + // This happens when the pool has already been + // closed and we're trying to return a + // connection to it anyway. Just close the + // connection. + c.Close() + } + }() + + select { + case cp.connections <- c: + default: + <-cp.createsem + c.Close() + } + } else { + <-cp.createsem + c.Close() + } +} + +// give the ability to discard a connection from a pool +// useful for ditching connections to the wrong node after a rebalance +func (cp *connectionPool) Discard(c *memcached.Client) { + <-cp.createsem + c.Close() +} + +// asynchronous connection closer +func (cp *connectionPool) connCloser() { + var connCount uint64 + + t := time.NewTimer(ConnCloserInterval) + defer t.Stop() + + for { + connCount = cp.connCount + + // we don't exist anymore! bail out! + select { + case <-cp.bailOut: + return + case <-t.C: + } + t.Reset(ConnCloserInterval) + + // no overflow connections open or sustained requests for connections + // nothing to do until the next cycle + if len(cp.connections) <= cp.poolSize || + ConnCloserInterval/ConnPoolAvailWaitTime < time.Duration(cp.connCount-connCount) { + continue + } + + // close overflow connections now that they are not needed + for c := range cp.connections { + select { + case <-cp.bailOut: + return + default: + } + + // bail out if close did not work out + if !cp.connCleanup(c) { + return + } + if len(cp.connections) <= cp.poolSize { + break + } + } + } +} + +// close connection with recovery on error +func (cp *connectionPool) connCleanup(c *memcached.Client) (rv bool) { + + // just in case we are closing a connection after + // bailOut has been sent but we haven't yet read it + defer func() { + if recover() != nil { + rv = false + } + }() + rv = true + + c.Close() + <-cp.createsem + return +} + +func (cp *connectionPool) StartTapFeed(args *memcached.TapArguments) (*memcached.TapFeed, error) { + if cp == nil { + return nil, errNoPool + } + mc, err := cp.Get() + if err != nil { + return nil, err + } + + // A connection can't be used after TAP; Dont' count it against the + // connection pool capacity + <-cp.createsem + + return mc.StartTapFeed(*args) +} + +const DEFAULT_WINDOW_SIZE = 20 * 1024 * 1024 // 20 Mb + +func (cp *connectionPool) StartUprFeed(name string, sequence uint32, dcp_buffer_size uint32, data_chan_size int) (*memcached.UprFeed, error) { + if cp == nil { + return nil, errNoPool + } + mc, err := cp.Get() + if err != nil { + return nil, err + } + + // A connection can't be used after it has been allocated to UPR; + // Dont' count it against the connection pool capacity + <-cp.createsem + + uf, err := mc.NewUprFeed() + if err != nil { + return nil, err + } + + if err := uf.UprOpen(name, sequence, dcp_buffer_size); err != nil { + return nil, err + } + + if err := uf.StartFeedWithConfig(data_chan_size); err != nil { + return nil, err + } + + return uf, nil +} diff --git a/vendor/github.com/couchbase/go-couchbase/ddocs.go b/vendor/github.com/couchbase/go-couchbase/ddocs.go new file mode 100644 index 00000000..f9cc343a --- 
/dev/null +++ b/vendor/github.com/couchbase/go-couchbase/ddocs.go @@ -0,0 +1,288 @@ +package couchbase + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/couchbase/goutils/logging" + "io/ioutil" + "net/http" +) + +// ViewDefinition represents a single view within a design document. +type ViewDefinition struct { + Map string `json:"map"` + Reduce string `json:"reduce,omitempty"` +} + +// DDoc is the document body of a design document specifying a view. +type DDoc struct { + Language string `json:"language,omitempty"` + Views map[string]ViewDefinition `json:"views"` +} + +// DDocsResult represents the result from listing the design +// documents. +type DDocsResult struct { + Rows []struct { + DDoc struct { + Meta map[string]interface{} + JSON DDoc + } `json:"doc"` + } `json:"rows"` +} + +// GetDDocs lists all design documents +func (b *Bucket) GetDDocs() (DDocsResult, error) { + var ddocsResult DDocsResult + b.RLock() + pool := b.pool + uri := b.DDocs.URI + b.RUnlock() + + // MB-23555 ephemeral buckets have no ddocs + if uri == "" { + return DDocsResult{}, nil + } + + err := pool.client.parseURLResponse(uri, &ddocsResult) + if err != nil { + return DDocsResult{}, err + } + return ddocsResult, nil +} + +func (b *Bucket) GetDDocWithRetry(docname string, into interface{}) error { + ddocURI := fmt.Sprintf("/%s/_design/%s", b.GetName(), docname) + err := b.parseAPIResponse(ddocURI, &into) + if err != nil { + return err + } + return nil +} + +func (b *Bucket) GetDDocsWithRetry() (DDocsResult, error) { + var ddocsResult DDocsResult + b.RLock() + uri := b.DDocs.URI + b.RUnlock() + + // MB-23555 ephemeral buckets have no ddocs + if uri == "" { + return DDocsResult{}, nil + } + + err := b.parseURLResponse(uri, &ddocsResult) + if err != nil { + return DDocsResult{}, err + } + return ddocsResult, nil +} + +func (b *Bucket) ddocURL(docname string) (string, error) { + u, err := b.randomBaseURL() + if err != nil { + return "", err + } + u.Path = fmt.Sprintf("/%s/_design/%s", b.GetName(), docname) + return u.String(), nil +} + +func (b *Bucket) ddocURLNext(nodeId int, docname string) (string, int, error) { + u, selected, err := b.randomNextURL(nodeId) + if err != nil { + return "", -1, err + } + u.Path = fmt.Sprintf("/%s/_design/%s", b.GetName(), docname) + return u.String(), selected, nil +} + +const ABS_MAX_RETRIES = 10 +const ABS_MIN_RETRIES = 3 + +func (b *Bucket) getMaxRetries() (int, error) { + + maxRetries := len(b.Nodes()) + + if maxRetries == 0 { + return 0, fmt.Errorf("No available Couch rest URLs") + } + + if maxRetries > ABS_MAX_RETRIES { + maxRetries = ABS_MAX_RETRIES + } else if maxRetries < ABS_MIN_RETRIES { + maxRetries = ABS_MIN_RETRIES + } + + return maxRetries, nil +} + +// PutDDoc installs a design document. 
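+//
+// A minimal design document (names and map function are illustrative):
+//
+//	err := bucket.PutDDoc("dev_byType", couchbase.DDoc{
+//		Views: map[string]couchbase.ViewDefinition{
+//			"by_type": {Map: "function (doc, meta) { emit(doc.type, null); }"},
+//		},
+//	})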
+func (b *Bucket) PutDDoc(docname string, value interface{}) error { + + var Err error + + maxRetries, err := b.getMaxRetries() + if err != nil { + return err + } + + lastNode := START_NODE_ID + + for retryCount := 0; retryCount < maxRetries; retryCount++ { + + Err = nil + + ddocU, selectedNode, err := b.ddocURLNext(lastNode, docname) + if err != nil { + return err + } + + lastNode = selectedNode + + logging.Infof(" Trying with selected node %d", selectedNode) + j, err := json.Marshal(value) + if err != nil { + return err + } + + req, err := http.NewRequest("PUT", ddocU, bytes.NewReader(j)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + err = maybeAddAuth(req, b.authHandler(false /* bucket not yet locked */)) + if err != nil { + return err + } + + res, err := doHTTPRequest(req) + if err != nil { + return err + } + + if res.StatusCode != 201 { + body, _ := ioutil.ReadAll(res.Body) + Err = fmt.Errorf("error installing view: %v / %s", + res.Status, body) + logging.Errorf(" Error in PutDDOC %v. Retrying...", Err) + res.Body.Close() + b.Refresh() + continue + } + + res.Body.Close() + break + } + + return Err +} + +// GetDDoc retrieves a specific a design doc. +func (b *Bucket) GetDDoc(docname string, into interface{}) error { + var Err error + var res *http.Response + + maxRetries, err := b.getMaxRetries() + if err != nil { + return err + } + + lastNode := START_NODE_ID + for retryCount := 0; retryCount < maxRetries; retryCount++ { + + Err = nil + ddocU, selectedNode, err := b.ddocURLNext(lastNode, docname) + if err != nil { + return err + } + + lastNode = selectedNode + logging.Infof(" Trying with selected node %d", selectedNode) + + req, err := http.NewRequest("GET", ddocU, nil) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + err = maybeAddAuth(req, b.authHandler(false /* bucket not yet locked */)) + if err != nil { + return err + } + + res, err = doHTTPRequest(req) + if err != nil { + return err + } + if res.StatusCode != 200 { + body, _ := ioutil.ReadAll(res.Body) + Err = fmt.Errorf("error reading view: %v / %s", + res.Status, body) + logging.Errorf(" Error in GetDDOC %v Retrying...", Err) + b.Refresh() + res.Body.Close() + continue + } + defer res.Body.Close() + break + } + + if Err != nil { + return Err + } + + d := json.NewDecoder(res.Body) + return d.Decode(into) +} + +// DeleteDDoc removes a design document. +func (b *Bucket) DeleteDDoc(docname string) error { + + var Err error + + maxRetries, err := b.getMaxRetries() + if err != nil { + return err + } + + lastNode := START_NODE_ID + + for retryCount := 0; retryCount < maxRetries; retryCount++ { + + Err = nil + ddocU, selectedNode, err := b.ddocURLNext(lastNode, docname) + if err != nil { + return err + } + + lastNode = selectedNode + logging.Infof(" Trying with selected node %d", selectedNode) + + req, err := http.NewRequest("DELETE", ddocU, nil) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + err = maybeAddAuth(req, b.authHandler(false /* bucket not already locked */)) + if err != nil { + return err + } + + res, err := doHTTPRequest(req) + if err != nil { + return err + } + if res.StatusCode != 200 { + body, _ := ioutil.ReadAll(res.Body) + Err = fmt.Errorf("error deleting view : %v / %s", res.Status, body) + logging.Errorf(" Error in DeleteDDOC %v. Retrying ... 
", Err) + b.Refresh() + res.Body.Close() + continue + } + + res.Body.Close() + break + } + return Err +} diff --git a/vendor/github.com/couchbase/go-couchbase/observe.go b/vendor/github.com/couchbase/go-couchbase/observe.go new file mode 100644 index 00000000..6e746f5a --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/observe.go @@ -0,0 +1,300 @@ +package couchbase + +import ( + "fmt" + "github.com/couchbase/goutils/logging" + "sync" +) + +type PersistTo uint8 + +const ( + PersistNone = PersistTo(0x00) + PersistMaster = PersistTo(0x01) + PersistOne = PersistTo(0x02) + PersistTwo = PersistTo(0x03) + PersistThree = PersistTo(0x04) + PersistFour = PersistTo(0x05) +) + +type ObserveTo uint8 + +const ( + ObserveNone = ObserveTo(0x00) + ObserveReplicateOne = ObserveTo(0x01) + ObserveReplicateTwo = ObserveTo(0x02) + ObserveReplicateThree = ObserveTo(0x03) + ObserveReplicateFour = ObserveTo(0x04) +) + +type JobType uint8 + +const ( + OBSERVE = JobType(0x00) + PERSIST = JobType(0x01) +) + +type ObservePersistJob struct { + vb uint16 + vbuuid uint64 + hostname string + jobType JobType + failover uint8 + lastPersistedSeqNo uint64 + currentSeqNo uint64 + resultChan chan *ObservePersistJob + errorChan chan *OPErrResponse +} + +type OPErrResponse struct { + vb uint16 + vbuuid uint64 + err error + job *ObservePersistJob +} + +var ObservePersistPool = NewPool(1024) +var OPJobChan = make(chan *ObservePersistJob, 1024) +var OPJobDone = make(chan bool) + +var wg sync.WaitGroup + +func (b *Bucket) StartOPPollers(maxWorkers int) { + + for i := 0; i < maxWorkers; i++ { + go b.OPJobPoll() + wg.Add(1) + } + wg.Wait() +} + +func (b *Bucket) SetObserveAndPersist(nPersist PersistTo, nObserve ObserveTo) (err error) { + + numNodes := len(b.Nodes()) + if int(nPersist) > numNodes || int(nObserve) > numNodes { + return fmt.Errorf("Not enough healthy nodes in the cluster") + } + + if int(nPersist) > (b.Replicas+1) || int(nObserve) > b.Replicas { + return fmt.Errorf("Not enough replicas in the cluster") + } + + if EnableMutationToken == false { + return fmt.Errorf("Mutation Tokens not enabled ") + } + + b.ds = &DurablitySettings{Persist: PersistTo(nPersist), Observe: ObserveTo(nObserve)} + return +} + +func (b *Bucket) ObserveAndPersistPoll(vb uint16, vbuuid uint64, seqNo uint64) (err error, failover bool) { + b.RLock() + ds := b.ds + b.RUnlock() + + if ds == nil { + return + } + + nj := 0 // total number of jobs + resultChan := make(chan *ObservePersistJob, 10) + errChan := make(chan *OPErrResponse, 10) + + nodes := b.GetNodeList(vb) + if int(ds.Observe) > len(nodes) || int(ds.Persist) > len(nodes) { + return fmt.Errorf("Not enough healthy nodes in the cluster"), false + } + + logging.Infof("Node list %v", nodes) + + if ds.Observe >= ObserveReplicateOne { + // create a job for each host + for i := ObserveReplicateOne; i < ds.Observe+1; i++ { + opJob := ObservePersistPool.Get() + opJob.vb = vb + opJob.vbuuid = vbuuid + opJob.jobType = OBSERVE + opJob.hostname = nodes[i] + opJob.resultChan = resultChan + opJob.errorChan = errChan + + OPJobChan <- opJob + nj++ + + } + } + + if ds.Persist >= PersistMaster { + for i := PersistMaster; i < ds.Persist+1; i++ { + opJob := ObservePersistPool.Get() + opJob.vb = vb + opJob.vbuuid = vbuuid + opJob.jobType = PERSIST + opJob.hostname = nodes[i] + opJob.resultChan = resultChan + opJob.errorChan = errChan + + OPJobChan <- opJob + nj++ + + } + } + + ok := true + for ok { + select { + case res := <-resultChan: + jobDone := false + if res.failover == 0 { + // no failover + if 
res.jobType == PERSIST { + if res.lastPersistedSeqNo >= seqNo { + jobDone = true + } + + } else { + if res.currentSeqNo >= seqNo { + jobDone = true + } + } + + if jobDone == true { + nj-- + ObservePersistPool.Put(res) + } else { + // requeue this job + OPJobChan <- res + } + + } else { + // Not currently handling failover scenarios TODO + nj-- + ObservePersistPool.Put(res) + failover = true + } + + if nj == 0 { + // done with all the jobs + ok = false + close(resultChan) + close(errChan) + } + + case Err := <-errChan: + logging.Errorf("Error in Observe/Persist %v", Err.err) + err = fmt.Errorf("Error in Observe/Persist job %v", Err.err) + nj-- + ObservePersistPool.Put(Err.job) + if nj == 0 { + close(resultChan) + close(errChan) + ok = false + } + } + } + + return +} + +func (b *Bucket) OPJobPoll() { + + ok := true + for ok == true { + select { + case job := <-OPJobChan: + pool := b.getConnPoolByHost(job.hostname, false /* bucket not already locked */) + if pool == nil { + errRes := &OPErrResponse{vb: job.vb, vbuuid: job.vbuuid} + errRes.err = fmt.Errorf("Pool not found for host %v", job.hostname) + errRes.job = job + job.errorChan <- errRes + continue + } + conn, err := pool.Get() + if err != nil { + errRes := &OPErrResponse{vb: job.vb, vbuuid: job.vbuuid} + errRes.err = fmt.Errorf("Unable to get connection from pool %v", err) + errRes.job = job + job.errorChan <- errRes + continue + } + + res, err := conn.ObserveSeq(job.vb, job.vbuuid) + if err != nil { + errRes := &OPErrResponse{vb: job.vb, vbuuid: job.vbuuid} + errRes.err = fmt.Errorf("Command failed %v", err) + errRes.job = job + job.errorChan <- errRes + continue + + } + pool.Return(conn) + job.lastPersistedSeqNo = res.LastPersistedSeqNo + job.currentSeqNo = res.CurrentSeqNo + job.failover = res.Failover + + job.resultChan <- job + case <-OPJobDone: + logging.Infof("Observe Persist Poller exitting") + ok = false + } + } + wg.Done() +} + +func (b *Bucket) GetNodeList(vb uint16) []string { + + vbm := b.VBServerMap() + if len(vbm.VBucketMap) < int(vb) { + logging.Infof("vbmap smaller than vblist") + return nil + } + + nodes := make([]string, len(vbm.VBucketMap[vb])) + for i := 0; i < len(vbm.VBucketMap[vb]); i++ { + n := vbm.VBucketMap[vb][i] + if n < 0 { + continue + } + + node := b.getMasterNode(n) + if len(node) > 1 { + nodes[i] = node + } + continue + + } + return nodes +} + +//pool of ObservePersist Jobs +type OPpool struct { + pool chan *ObservePersistJob +} + +// NewPool creates a new pool of jobs +func NewPool(max int) *OPpool { + return &OPpool{ + pool: make(chan *ObservePersistJob, max), + } +} + +// Borrow a Client from the pool. +func (p *OPpool) Get() *ObservePersistJob { + var o *ObservePersistJob + select { + case o = <-p.pool: + default: + o = &ObservePersistJob{} + } + return o +} + +// Return returns a Client to the pool. +func (p *OPpool) Put(o *ObservePersistJob) { + select { + case p.pool <- o: + default: + // let it go, let it go... 
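+		// (the pool is only an allocation cache; when it is full the job
+		// is simply dropped and left to the garbage collector)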
+ } +} diff --git a/vendor/github.com/couchbase/go-couchbase/pools.go b/vendor/github.com/couchbase/go-couchbase/pools.go new file mode 100644 index 00000000..c6f0a20b --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/pools.go @@ -0,0 +1,1246 @@ +package couchbase + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "runtime" + "sort" + "strings" + "sync" + "unsafe" + + "github.com/couchbase/goutils/logging" + + "github.com/couchbase/gomemcached" // package name is 'gomemcached' + "github.com/couchbase/gomemcached/client" // package name is 'memcached' +) + +// HTTPClient to use for REST and view operations. +var MaxIdleConnsPerHost = 256 +var HTTPTransport = &http.Transport{MaxIdleConnsPerHost: MaxIdleConnsPerHost} +var HTTPClient = &http.Client{Transport: HTTPTransport} + +// PoolSize is the size of each connection pool (per host). +var PoolSize = 64 + +// PoolOverflow is the number of overflow connections allowed in a +// pool. +var PoolOverflow = 16 + +// AsynchronousCloser turns on asynchronous closing for overflow connections +var AsynchronousCloser = false + +// TCP KeepAlive enabled/disabled +var TCPKeepalive = false + +// Enable MutationToken +var EnableMutationToken = false + +// Enable Data Type response +var EnableDataType = false + +// Enable Xattr +var EnableXattr = false + +// TCP keepalive interval in seconds. Default 30 minutes +var TCPKeepaliveInterval = 30 * 60 + +// Used to decide whether to skip verification of certificates when +// connecting to an ssl port. +var skipVerify = true +var certFile = "" +var keyFile = "" +var rootFile = "" + +func SetSkipVerify(skip bool) { + skipVerify = skip +} + +func SetCertFile(cert string) { + certFile = cert +} + +func SetKeyFile(cert string) { + keyFile = cert +} + +func SetRootFile(cert string) { + rootFile = cert +} + +// Allow applications to speciify the Poolsize and Overflow +func SetConnectionPoolParams(size, overflow int) { + + if size > 0 { + PoolSize = size + } + + if overflow > 0 { + PoolOverflow = overflow + } +} + +// Turn off overflow connections +func DisableOverflowConnections() { + PoolOverflow = 0 +} + +// Toggle asynchronous overflow closer +func EnableAsynchronousCloser(closer bool) { + AsynchronousCloser = closer +} + +// Allow TCP keepalive parameters to be set by the application +func SetTcpKeepalive(enabled bool, interval int) { + + TCPKeepalive = enabled + + if interval > 0 { + TCPKeepaliveInterval = interval + } +} + +// AuthHandler is a callback that gets the auth username and password +// for the given bucket. +type AuthHandler interface { + GetCredentials() (string, string, string) +} + +// AuthHandler is a callback that gets the auth username and password +// for the given bucket and sasl for memcached. +type AuthWithSaslHandler interface { + AuthHandler + GetSaslCredentials() (string, string) +} + +// MultiBucketAuthHandler is kind of AuthHandler that may perform +// different auth for different buckets. +type MultiBucketAuthHandler interface { + AuthHandler + ForBucket(bucket string) AuthHandler +} + +// HTTPAuthHandler is kind of AuthHandler that performs more general +// for outgoing http requests than is possible via simple +// GetCredentials() call (i.e. digest auth or different auth per +// different destinations). 
+type HTTPAuthHandler interface {
+	AuthHandler
+	SetCredsForRequest(req *http.Request) error
+}
+
+// RestPool represents a single pool returned from the pools REST API.
+type RestPool struct {
+	Name         string `json:"name"`
+	StreamingURI string `json:"streamingUri"`
+	URI          string `json:"uri"`
+}
+
+// Pools represents the collection of pools as returned from the REST API.
+type Pools struct {
+	ComponentsVersion     map[string]string `json:"componentsVersion,omitempty"`
+	ImplementationVersion string            `json:"implementationVersion"`
+	IsAdmin               bool              `json:"isAdminCreds"`
+	UUID                  string            `json:"uuid"`
+	Pools                 []RestPool        `json:"pools"`
+}
+
+// A Node is a computer in a cluster running the couchbase software.
+type Node struct {
+	ClusterCompatibility int                `json:"clusterCompatibility"`
+	ClusterMembership    string             `json:"clusterMembership"`
+	CouchAPIBase         string             `json:"couchApiBase"`
+	Hostname             string             `json:"hostname"`
+	InterestingStats     map[string]float64 `json:"interestingStats,omitempty"`
+	MCDMemoryAllocated   float64            `json:"mcdMemoryAllocated"`
+	MCDMemoryReserved    float64            `json:"mcdMemoryReserved"`
+	MemoryFree           float64            `json:"memoryFree"`
+	MemoryTotal          float64            `json:"memoryTotal"`
+	OS                   string             `json:"os"`
+	Ports                map[string]int     `json:"ports"`
+	Services             []string           `json:"services"`
+	Status               string             `json:"status"`
+	Uptime               int                `json:"uptime,string"`
+	Version              string             `json:"version"`
+	ThisNode             bool               `json:"thisNode,omitempty"`
+}
+
+// A Pool of nodes and buckets.
+type Pool struct {
+	BucketMap map[string]Bucket
+	Nodes     []Node
+
+	BucketURL map[string]string `json:"buckets"`
+
+	client Client
+}
+
+// VBucketServerMap is the mapping of vbuckets to nodes.
+type VBucketServerMap struct {
+	HashAlgorithm string   `json:"hashAlgorithm"`
+	NumReplicas   int      `json:"numReplicas"`
+	ServerList    []string `json:"serverList"`
+	VBucketMap    [][]int  `json:"vBucketMap"`
+}
+
+type DurablitySettings struct {
+	Persist PersistTo
+	Observe ObserveTo
+}
+
+// Bucket is the primary entry point for most data operations.
+// Bucket is a locked data structure. All access to its fields should be done using read or write locking,
+// as appropriate.
+//
+// Some access methods require locking, but rely on the caller to do so. These are appropriate
+// for calls from methods that have already locked the structure. Methods like this
+// take a boolean parameter "bucketLocked".
+type Bucket struct {
+	sync.RWMutex
+	AuthType               string             `json:"authType"`
+	Capabilities           []string           `json:"bucketCapabilities"`
+	CapabilitiesVersion    string             `json:"bucketCapabilitiesVer"`
+	Type                   string             `json:"bucketType"`
+	Name                   string             `json:"name"`
+	NodeLocator            string             `json:"nodeLocator"`
+	Quota                  map[string]float64 `json:"quota,omitempty"`
+	Replicas               int                `json:"replicaNumber"`
+	Password               string             `json:"saslPassword"`
+	URI                    string             `json:"uri"`
+	StreamingURI           string             `json:"streamingUri"`
+	LocalRandomKeyURI      string             `json:"localRandomKeyUri,omitempty"`
+	UUID                   string             `json:"uuid"`
+	ConflictResolutionType string             `json:"conflictResolutionType,omitempty"`
+	DDocs                  struct {
+		URI string `json:"uri"`
+	} `json:"ddocs,omitempty"`
+	BasicStats  map[string]interface{} `json:"basicStats,omitempty"`
+	Controllers map[string]interface{} `json:"controllers,omitempty"`
+
+	// These are used for JSON IO, but aren't used for processing
+	// since they need to be swapped out safely.
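+	// (The live copies sit behind the unsafe.Pointer fields below and
+	// are swapped under the bucket lock on Refresh.)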
+	VBSMJson  VBucketServerMap `json:"vBucketServerMap"`
+	NodesJSON []Node           `json:"nodes"`
+
+	pool             *Pool
+	connPools        unsafe.Pointer // *[]*connectionPool
+	vBucketServerMap unsafe.Pointer // *VBucketServerMap
+	nodeList         unsafe.Pointer // *[]Node
+	commonSufix      string
+	ah               AuthHandler        // auth handler
+	ds               *DurablitySettings // Durability settings for this bucket
+}
+
+// PoolServices is all the bucket-independent services in a pool
+type PoolServices struct {
+	Rev      int            `json:"rev"`
+	NodesExt []NodeServices `json:"nodesExt"`
+}
+
+// NodeServices is all the bucket-independent services running on
+// a node (given by Hostname)
+type NodeServices struct {
+	Services map[string]int `json:"services,omitempty"`
+	Hostname string         `json:"hostname"`
+	ThisNode bool           `json:"thisNode"`
+}
+
+type BucketNotFoundError struct {
+	bucket string
+}
+
+func (e *BucketNotFoundError) Error() string {
+	return fmt.Sprint("No bucket named " + e.bucket)
+}
+
+type BucketAuth struct {
+	name    string
+	saslPwd string
+	bucket  string
+}
+
+func newBucketAuth(name string, pass string, bucket string) *BucketAuth {
+	return &BucketAuth{name: name, saslPwd: pass, bucket: bucket}
+}
+
+func (ba *BucketAuth) GetCredentials() (string, string, string) {
+	return ba.name, ba.saslPwd, ba.bucket
+}
+
+// VBServerMap returns the current VBucketServerMap.
+func (b *Bucket) VBServerMap() *VBucketServerMap {
+	b.RLock()
+	defer b.RUnlock()
+	ret := (*VBucketServerMap)(b.vBucketServerMap)
+	return ret
+}
+
+func (b *Bucket) GetVBmap(addrs []string) (map[string][]uint16, error) {
+	vbmap := b.VBServerMap()
+	servers := vbmap.ServerList
+	if addrs == nil {
+		addrs = vbmap.ServerList
+	}
+
+	m := make(map[string][]uint16)
+	for _, addr := range addrs {
+		m[addr] = make([]uint16, 0)
+	}
+	for vbno, idxs := range vbmap.VBucketMap {
+		if len(idxs) == 0 {
+			return nil, fmt.Errorf("vbmap: No KV node for vb %d", vbno)
+		} else if idxs[0] < 0 || idxs[0] >= len(servers) {
+			return nil, fmt.Errorf("vbmap: Invalid KV node no %d for vb %d", idxs[0], vbno)
+		}
+		addr := servers[idxs[0]]
+		if _, ok := m[addr]; ok {
+			m[addr] = append(m[addr], uint16(vbno))
+		}
+	}
+	return m, nil
+}
+
+// true if node is not on the bucket VBmap
+func (b *Bucket) checkVBmap(node string) bool {
+	vbmap := b.VBServerMap()
+	servers := vbmap.ServerList
+
+	for _, idxs := range vbmap.VBucketMap {
+		if len(idxs) == 0 {
+			return true
+		} else if idxs[0] < 0 || idxs[0] >= len(servers) {
+			return true
+		}
+		if servers[idxs[0]] == node {
+			return false
+		}
+	}
+	return true
+}
+
+func (b *Bucket) GetName() string {
+	b.RLock()
+	defer b.RUnlock()
+	ret := b.Name
+	return ret
+}
+
+// Nodes returns the current list of nodes servicing this bucket.
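+//
+// A usage sketch (bucket acquisition elided):
+//
+//	for _, node := range bucket.Nodes() {
+//		fmt.Printf("%s (%s)\n", node.Hostname, node.Status)
+//	}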
+func (b *Bucket) Nodes() []Node { + b.RLock() + defer b.RUnlock() + ret := *(*[]Node)(b.nodeList) + return ret +} + +// return the list of healthy nodes +func (b *Bucket) HealthyNodes() []Node { + nodes := []Node{} + + for _, n := range b.Nodes() { + if n.Status == "healthy" && n.CouchAPIBase != "" { + nodes = append(nodes, n) + } + if n.Status != "healthy" { // log non-healthy node + logging.Infof("Non-healthy node; node details:") + logging.Infof("Hostname=%v, Status=%v, CouchAPIBase=%v, ThisNode=%v", n.Hostname, n.Status, n.CouchAPIBase, n.ThisNode) + } + } + + return nodes +} + +func (b *Bucket) getConnPools(bucketLocked bool) []*connectionPool { + if !bucketLocked { + b.RLock() + defer b.RUnlock() + } + if b.connPools != nil { + return *(*[]*connectionPool)(b.connPools) + } else { + return nil + } +} + +func (b *Bucket) replaceConnPools(with []*connectionPool) { + b.Lock() + defer b.Unlock() + + old := b.connPools + b.connPools = unsafe.Pointer(&with) + if old != nil { + for _, pool := range *(*[]*connectionPool)(old) { + if pool != nil { + pool.Close() + } + } + } + return +} + +func (b *Bucket) getConnPool(i int) *connectionPool { + + if i < 0 { + return nil + } + + p := b.getConnPools(false /* not already locked */) + if len(p) > i { + return p[i] + } + + return nil +} + +func (b *Bucket) getConnPoolByHost(host string, bucketLocked bool) *connectionPool { + pools := b.getConnPools(bucketLocked) + for _, p := range pools { + if p != nil && p.host == host { + return p + } + } + + return nil +} + +// Given a vbucket number, returns a memcached connection to it. +// The connection must be returned to its pool after use. +func (b *Bucket) getConnectionToVBucket(vb uint32) (*memcached.Client, *connectionPool, error) { + for { + vbm := b.VBServerMap() + if len(vbm.VBucketMap) < int(vb) { + return nil, nil, fmt.Errorf("go-couchbase: vbmap smaller than vbucket list: %v vs. %v", + vb, vbm.VBucketMap) + } + masterId := vbm.VBucketMap[vb][0] + if masterId < 0 { + return nil, nil, fmt.Errorf("go-couchbase: No master for vbucket %d", vb) + } + pool := b.getConnPool(masterId) + conn, err := pool.Get() + if err != errClosedPool { + return conn, pool, err + } + // If conn pool was closed, because another goroutine refreshed the vbucket map, retry... + } +} + +// To get random documents, we need to cover all the nodes, so select +// a connection at random. + +func (b *Bucket) getRandomConnection() (*memcached.Client, *connectionPool, error) { + for { + var currentPool = 0 + pools := b.getConnPools(false /* not already locked */) + if len(pools) == 0 { + return nil, nil, fmt.Errorf("No connection pool found") + } else if len(pools) > 1 { // choose a random connection + currentPool = rand.Intn(len(pools)) + } // if only one pool, currentPool defaults to 0, i.e., the only pool + + // get the pool + pool := pools[currentPool] + conn, err := pool.Get() + if err != errClosedPool { + return conn, pool, err + } + + // If conn pool was closed, because another goroutine refreshed the vbucket map, retry... + } +} + +// +// Get a random document from a bucket. Since the bucket may be distributed +// across nodes, we must first select a random connection, and then use the +// Client.GetRandomDoc() call to get a random document from that node. 
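+//
+// A usage sketch (error handling elided):
+//
+//	res, err := bucket.GetRandomDoc()
+//	if err == nil {
+//		fmt.Printf("%s => %s\n", res.Key, res.Body)
+//	}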
+// + +func (b *Bucket) GetRandomDoc() (*gomemcached.MCResponse, error) { + // get a connection from the pool + conn, pool, err := b.getRandomConnection() + + if err != nil { + return nil, err + } + + // get a randomm document from the connection + doc, err := conn.GetRandomDoc() + // need to return the connection to the pool + pool.Return(conn) + return doc, err +} + +func (b *Bucket) getMasterNode(i int) string { + p := b.getConnPools(false /* not already locked */) + if len(p) > i { + return p[i].host + } + return "" +} + +func (b *Bucket) authHandler(bucketLocked bool) (ah AuthHandler) { + if !bucketLocked { + b.RLock() + defer b.RUnlock() + } + pool := b.pool + name := b.Name + + if pool != nil { + ah = pool.client.ah + } + if mbah, ok := ah.(MultiBucketAuthHandler); ok { + return mbah.ForBucket(name) + } + if ah == nil { + ah = &basicAuth{name, ""} + } + return +} + +// NodeAddresses gets the (sorted) list of memcached node addresses +// (hostname:port). +func (b *Bucket) NodeAddresses() []string { + vsm := b.VBServerMap() + rv := make([]string, len(vsm.ServerList)) + copy(rv, vsm.ServerList) + sort.Strings(rv) + return rv +} + +// CommonAddressSuffix finds the longest common suffix of all +// host:port strings in the node list. +func (b *Bucket) CommonAddressSuffix() string { + input := []string{} + for _, n := range b.Nodes() { + input = append(input, n.Hostname) + } + return FindCommonSuffix(input) +} + +// A Client is the starting point for all services across all buckets +// in a Couchbase cluster. +type Client struct { + BaseURL *url.URL + ah AuthHandler + Info Pools +} + +func maybeAddAuth(req *http.Request, ah AuthHandler) error { + if hah, ok := ah.(HTTPAuthHandler); ok { + return hah.SetCredsForRequest(req) + } + if ah != nil { + user, pass, _ := ah.GetCredentials() + req.Header.Set("Authorization", "Basic "+ + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))) + } + return nil +} + +// arbitary number, may need to be tuned #FIXME +const HTTP_MAX_RETRY = 5 + +// Someday golang network packages will implement standard +// error codes. 
Until then #sigh +func isHttpConnError(err error) bool { + + estr := err.Error() + return strings.Contains(estr, "broken pipe") || + strings.Contains(estr, "broken connection") || + strings.Contains(estr, "connection reset") +} + +var client *http.Client + +func ClientConfigForX509(certFile, keyFile, rootFile string) (*tls.Config, error) { + cfg := &tls.Config{} + + if certFile != "" && keyFile != "" { + tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + cfg.Certificates = []tls.Certificate{tlsCert} + } else { + //error need to pass both certfile and keyfile + return nil, fmt.Errorf("N1QL: Need to pass both certfile and keyfile") + } + + var caCert []byte + var err1 error + + caCertPool := x509.NewCertPool() + if rootFile != "" { + // Read that value in + caCert, err1 = ioutil.ReadFile(rootFile) + if err1 != nil { + return nil, fmt.Errorf(" Error in reading cacert file, err: %v", err1) + } + caCertPool.AppendCertsFromPEM(caCert) + } + + cfg.RootCAs = caCertPool + return cfg, nil +} + +func doHTTPRequest(req *http.Request) (*http.Response, error) { + + var err error + var res *http.Response + + tr := &http.Transport{} + + // we need a client that ignores certificate errors, since we self-sign + // our certs + if client == nil && req.URL.Scheme == "https" { + if skipVerify { + tr = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } else { + // Handle cases with cert + + cfg, err := ClientConfigForX509(certFile, keyFile, rootFile) + if err != nil { + return nil, err + } + + tr = &http.Transport{ + TLSClientConfig: cfg, + } + } + + client = &http.Client{Transport: tr} + + } else if client == nil { + client = HTTPClient + } + + for i := 0; i < HTTP_MAX_RETRY; i++ { + res, err = client.Do(req) + if err != nil && isHttpConnError(err) { + continue + } + break + } + + if err != nil { + return nil, err + } + + return res, err +} + +func doPutAPI(baseURL *url.URL, path string, params map[string]interface{}, authHandler AuthHandler, out interface{}) error { + return doOutputAPI("PUT", baseURL, path, params, authHandler, out) +} + +func doPostAPI(baseURL *url.URL, path string, params map[string]interface{}, authHandler AuthHandler, out interface{}) error { + return doOutputAPI("POST", baseURL, path, params, authHandler, out) +} + +func doOutputAPI( + httpVerb string, + baseURL *url.URL, + path string, + params map[string]interface{}, + authHandler AuthHandler, + out interface{}) error { + + var requestUrl string + + if q := strings.Index(path, "?"); q > 0 { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path[:q] + "?" 
+ path[q+1:] + } else { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path + } + + postData := url.Values{} + for k, v := range params { + postData.Set(k, fmt.Sprintf("%v", v)) + } + + req, err := http.NewRequest(httpVerb, requestUrl, bytes.NewBufferString(postData.Encode())) + if err != nil { + return err + } + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + err = maybeAddAuth(req, authHandler) + if err != nil { + return err + } + + res, err := doHTTPRequest(req) + if err != nil { + return err + } + + defer res.Body.Close() + if res.StatusCode != 200 { + bod, _ := ioutil.ReadAll(io.LimitReader(res.Body, 512)) + return fmt.Errorf("HTTP error %v getting %q: %s", + res.Status, requestUrl, bod) + } + + d := json.NewDecoder(res.Body) + if err = d.Decode(&out); err != nil { + return err + } + return nil +} + +func queryRestAPI( + baseURL *url.URL, + path string, + authHandler AuthHandler, + out interface{}) error { + + var requestUrl string + + if q := strings.Index(path, "?"); q > 0 { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path[:q] + "?" + path[q+1:] + } else { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path + } + + req, err := http.NewRequest("GET", requestUrl, nil) + if err != nil { + return err + } + + err = maybeAddAuth(req, authHandler) + if err != nil { + return err + } + + res, err := doHTTPRequest(req) + if err != nil { + return err + } + + defer res.Body.Close() + if res.StatusCode != 200 { + bod, _ := ioutil.ReadAll(io.LimitReader(res.Body, 512)) + return fmt.Errorf("HTTP error %v getting %q: %s", + res.Status, requestUrl, bod) + } + + d := json.NewDecoder(res.Body) + if err = d.Decode(&out); err != nil { + return err + } + return nil +} + +func (c *Client) ProcessStream(path string, callb func(interface{}) error, data interface{}) error { + return c.processStream(c.BaseURL, path, c.ah, callb, data) +} + +// Based on code in http://src.couchbase.org/source/xref/trunk/goproj/src/github.com/couchbase/indexing/secondary/dcp/pools.go#309 +func (c *Client) processStream(baseURL *url.URL, path string, authHandler AuthHandler, callb func(interface{}) error, data interface{}) error { + var requestUrl string + + if q := strings.Index(path, "?"); q > 0 { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path[:q] + "?" 
+ path[q+1:] + } else { + requestUrl = baseURL.Scheme + "://" + baseURL.Host + path + } + + req, err := http.NewRequest("GET", requestUrl, nil) + if err != nil { + return err + } + + err = maybeAddAuth(req, authHandler) + if err != nil { + return err + } + + res, err := doHTTPRequest(req) + if err != nil { + return err + } + + defer res.Body.Close() + if res.StatusCode != 200 { + bod, _ := ioutil.ReadAll(io.LimitReader(res.Body, 512)) + return fmt.Errorf("HTTP error %v getting %q: %s", + res.Status, requestUrl, bod) + } + + reader := bufio.NewReader(res.Body) + for { + bs, err := reader.ReadBytes('\n') + if err != nil { + return err + } + if len(bs) == 1 && bs[0] == '\n' { + continue + } + + err = json.Unmarshal(bs, data) + if err != nil { + return err + } + err = callb(data) + if err != nil { + return err + } + } + return nil + +} + +func (c *Client) parseURLResponse(path string, out interface{}) error { + return queryRestAPI(c.BaseURL, path, c.ah, out) +} + +func (c *Client) parsePostURLResponse(path string, params map[string]interface{}, out interface{}) error { + return doPostAPI(c.BaseURL, path, params, c.ah, out) +} + +func (c *Client) parsePutURLResponse(path string, params map[string]interface{}, out interface{}) error { + return doPutAPI(c.BaseURL, path, params, c.ah, out) +} + +func (b *Bucket) parseURLResponse(path string, out interface{}) error { + nodes := b.Nodes() + if len(nodes) == 0 { + return errors.New("no couch rest URLs") + } + + // Pick a random node to start querying. + startNode := rand.Intn(len(nodes)) + maxRetries := len(nodes) + for i := 0; i < maxRetries; i++ { + node := nodes[(startNode+i)%len(nodes)] // Wrap around the nodes list. + // Skip non-healthy nodes. + if node.Status != "healthy" || node.CouchAPIBase == "" { + continue + } + url := &url.URL{ + Host: node.Hostname, + Scheme: "http", + } + + // Lock here to avoid having pool closed under us. + b.RLock() + err := queryRestAPI(url, path, b.pool.client.ah, out) + b.RUnlock() + if err == nil { + return err + } + } + return errors.New("All nodes failed to respond or no healthy nodes for bucket found") +} + +func (b *Bucket) parseAPIResponse(path string, out interface{}) error { + nodes := b.Nodes() + if len(nodes) == 0 { + return errors.New("no couch rest URLs") + } + + var err error + var u *url.URL + + // Pick a random node to start querying. + startNode := rand.Intn(len(nodes)) + maxRetries := len(nodes) + for i := 0; i < maxRetries; i++ { + node := nodes[(startNode+i)%len(nodes)] // Wrap around the nodes list. + // Skip non-healthy nodes. + if node.Status != "healthy" || node.CouchAPIBase == "" { + continue + } + + u, err = ParseURL(node.CouchAPIBase) + // Lock here so pool does not get closed under us. + b.RLock() + if err != nil { + b.RUnlock() + return fmt.Errorf("config error: Bucket %q node #%d CouchAPIBase=%q: %v", + b.Name, i, node.CouchAPIBase, err) + } else if b.pool != nil { + u.User = b.pool.client.BaseURL.User + } + u.Path = path + + // generate the path so that the strings are properly escaped + // MB-13770 + requestPath := strings.Split(u.String(), u.Host)[1] + + err = queryRestAPI(u, requestPath, b.pool.client.ah, out) + b.RUnlock() + if err == nil { + return err + } + } + + var errStr string + if err != nil { + errStr = "Error " + err.Error() + } + + return errors.New("All nodes failed to respond or returned error or no healthy nodes for bucket found." 
+ errStr) +} + +type basicAuth struct { + u, p string +} + +func (b basicAuth) GetCredentials() (string, string, string) { + return b.u, b.p, b.u +} + +func basicAuthFromURL(us string) (ah AuthHandler) { + u, err := ParseURL(us) + if err != nil { + return + } + if user := u.User; user != nil { + pw, _ := user.Password() + ah = basicAuth{user.Username(), pw} + } + return +} + +// ConnectWithAuth connects to a couchbase cluster with the given +// authentication handler. +func ConnectWithAuth(baseU string, ah AuthHandler) (c Client, err error) { + c.BaseURL, err = ParseURL(baseU) + if err != nil { + return + } + c.ah = ah + + return c, c.parseURLResponse("/pools", &c.Info) +} + +// ConnectWithAuthCreds connects to a couchbase cluster with the give +// authorization creds returned by cb_auth +func ConnectWithAuthCreds(baseU, username, password string) (c Client, err error) { + c.BaseURL, err = ParseURL(baseU) + if err != nil { + return + } + + c.ah = newBucketAuth(username, password, "") + return c, c.parseURLResponse("/pools", &c.Info) + +} + +// Connect to a couchbase cluster. An authentication handler will be +// created from the userinfo in the URL if provided. +func Connect(baseU string) (Client, error) { + return ConnectWithAuth(baseU, basicAuthFromURL(baseU)) +} + +type BucketInfo struct { + Name string // name of bucket + Password string // SASL password of bucket +} + +//Get SASL buckets +func GetBucketList(baseU string) (bInfo []BucketInfo, err error) { + + c := &Client{} + c.BaseURL, err = ParseURL(baseU) + if err != nil { + return + } + c.ah = basicAuthFromURL(baseU) + + var buckets []Bucket + err = c.parseURLResponse("/pools/default/buckets", &buckets) + if err != nil { + return + } + bInfo = make([]BucketInfo, 0) + for _, bucket := range buckets { + bucketInfo := BucketInfo{Name: bucket.Name, Password: bucket.Password} + bInfo = append(bInfo, bucketInfo) + } + return bInfo, err +} + +//Set viewUpdateDaemonOptions +func SetViewUpdateParams(baseU string, params map[string]interface{}) (viewOpts map[string]interface{}, err error) { + + c := &Client{} + c.BaseURL, err = ParseURL(baseU) + if err != nil { + return + } + c.ah = basicAuthFromURL(baseU) + + if len(params) < 1 { + return nil, fmt.Errorf("No params to set") + } + + err = c.parsePostURLResponse("/settings/viewUpdateDaemon", params, &viewOpts) + if err != nil { + return + } + return viewOpts, err +} + +// This API lets the caller know, if the list of nodes a bucket is +// connected to has gone through an edit (a rebalance operation) +// since the last update to the bucket, in which case a Refresh is +// advised. 
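+//
+// A polling sketch (the caller decides the cadence):
+//
+//	if bucket.NodeListChanged() {
+//		_ = bucket.Refresh()
+//	}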
+func (b *Bucket) NodeListChanged() bool { + b.RLock() + pool := b.pool + uri := b.URI + b.RUnlock() + + tmpb := &Bucket{} + err := pool.client.parseURLResponse(uri, tmpb) + if err != nil { + return true + } + + bNodes := *(*[]Node)(b.nodeList) + if len(bNodes) != len(tmpb.NodesJSON) { + return true + } + + bucketHostnames := map[string]bool{} + for _, node := range bNodes { + bucketHostnames[node.Hostname] = true + } + + for _, node := range tmpb.NodesJSON { + if _, found := bucketHostnames[node.Hostname]; !found { + return true + } + } + + return false +} + +func (b *Bucket) Refresh() error { + b.RLock() + pool := b.pool + uri := b.URI + b.RUnlock() + + tmpb := &Bucket{} + err := pool.client.parseURLResponse(uri, tmpb) + if err != nil { + return err + } + + pools := b.getConnPools(false /* bucket not already locked */) + + // We need this lock to ensure that bucket refreshes happening because + // of NMVb errors received during bulkGet do not end up over-writing + // pool.inUse. + b.Lock() + + for _, pool := range pools { + if pool != nil { + pool.inUse = false + } + } + + newcps := make([]*connectionPool, len(tmpb.VBSMJson.ServerList)) + for i := range newcps { + + pool := b.getConnPoolByHost(tmpb.VBSMJson.ServerList[i], true /* bucket already locked */) + if pool != nil && pool.inUse == false { + // if the hostname and index is unchanged then reuse this pool + newcps[i] = pool + pool.inUse = true + continue + } + + if b.ah != nil { + newcps[i] = newConnectionPool( + tmpb.VBSMJson.ServerList[i], + b.ah, AsynchronousCloser, PoolSize, PoolOverflow) + + } else { + newcps[i] = newConnectionPool( + tmpb.VBSMJson.ServerList[i], + b.authHandler(true /* bucket already locked */), + AsynchronousCloser, PoolSize, PoolOverflow) + } + } + b.replaceConnPools2(newcps, true /* bucket already locked */) + tmpb.ah = b.ah + b.vBucketServerMap = unsafe.Pointer(&tmpb.VBSMJson) + b.nodeList = unsafe.Pointer(&tmpb.NodesJSON) + + b.Unlock() + return nil +} + +func (p *Pool) refresh() (err error) { + p.BucketMap = make(map[string]Bucket) + + buckets := []Bucket{} + err = p.client.parseURLResponse(p.BucketURL["uri"], &buckets) + if err != nil { + return err + } + for _, b := range buckets { + b.pool = p + b.nodeList = unsafe.Pointer(&b.NodesJSON) + b.replaceConnPools(make([]*connectionPool, len(b.VBSMJson.ServerList))) + + p.BucketMap[b.Name] = b + } + return nil +} + +// GetPool gets a pool from within the couchbase cluster (usually +// "default"). +func (c *Client) GetPool(name string) (p Pool, err error) { + var poolURI string + for _, p := range c.Info.Pools { + if p.Name == name { + poolURI = p.URI + } + } + if poolURI == "" { + return p, errors.New("No pool named " + name) + } + + err = c.parseURLResponse(poolURI, &p) + + p.client = *c + + err = p.refresh() + return +} + +// GetPoolServices returns all the bucket-independent services in a pool. +// (See "Exposing services outside of bucket context" in http://goo.gl/uuXRkV) +func (c *Client) GetPoolServices(name string) (ps PoolServices, err error) { + var poolName string + for _, p := range c.Info.Pools { + if p.Name == name { + poolName = p.Name + } + } + if poolName == "" { + return ps, errors.New("No pool named " + name) + } + + poolURI := "/pools/" + poolName + "/nodeServices" + err = c.parseURLResponse(poolURI, &ps) + + return +} + +// Close marks this bucket as no longer needed, closing connections it +// may have open. 
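+//
+// A usage sketch (the bucket name is illustrative):
+//
+//	bucket, err := pool.GetBucket("default")
+//	if err != nil {
+//		// handle error
+//	}
+//	defer bucket.Close()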
+func (b *Bucket) Close() { + b.Lock() + defer b.Unlock() + if b.connPools != nil { + for _, c := range b.getConnPools(true /* already locked */) { + if c != nil { + c.Close() + } + } + b.connPools = nil + } +} + +func bucketFinalizer(b *Bucket) { + if b.connPools != nil { + logging.Warnf("Finalizing a bucket with active connections.") + } +} + +// GetBucket gets a bucket from within this pool. +func (p *Pool) GetBucket(name string) (*Bucket, error) { + rv, ok := p.BucketMap[name] + if !ok { + return nil, &BucketNotFoundError{bucket: name} + } + runtime.SetFinalizer(&rv, bucketFinalizer) + err := rv.Refresh() + if err != nil { + return nil, err + } + return &rv, nil +} + +// GetBucket gets a bucket from within this pool. +func (p *Pool) GetBucketWithAuth(bucket, username, password string) (*Bucket, error) { + rv, ok := p.BucketMap[bucket] + if !ok { + return nil, &BucketNotFoundError{bucket: bucket} + } + runtime.SetFinalizer(&rv, bucketFinalizer) + rv.ah = newBucketAuth(username, password, bucket) + err := rv.Refresh() + if err != nil { + return nil, err + } + return &rv, nil +} + +// GetPool gets the pool to which this bucket belongs. +func (b *Bucket) GetPool() *Pool { + b.RLock() + defer b.RUnlock() + ret := b.pool + return ret +} + +// GetClient gets the client from which we got this pool. +func (p *Pool) GetClient() *Client { + return &p.client +} + +// GetBucket is a convenience function for getting a named bucket from +// a URL +func GetBucket(endpoint, poolname, bucketname string) (*Bucket, error) { + var err error + client, err := Connect(endpoint) + if err != nil { + return nil, err + } + + pool, err := client.GetPool(poolname) + if err != nil { + return nil, err + } + + return pool.GetBucket(bucketname) +} + +// ConnectWithAuthAndGetBucket is a convenience function for +// getting a named bucket from a given URL and an auth callback +func ConnectWithAuthAndGetBucket(endpoint, poolname, bucketname string, + ah AuthHandler) (*Bucket, error) { + client, err := ConnectWithAuth(endpoint, ah) + if err != nil { + return nil, err + } + + pool, err := client.GetPool(poolname) + if err != nil { + return nil, err + } + + return pool.GetBucket(bucketname) +} diff --git a/vendor/github.com/couchbase/go-couchbase/streaming.go b/vendor/github.com/couchbase/go-couchbase/streaming.go new file mode 100644 index 00000000..8ac6f531 --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/streaming.go @@ -0,0 +1,199 @@ +package couchbase + +import ( + "encoding/json" + "fmt" + "github.com/couchbase/goutils/logging" + "io" + "io/ioutil" + "math/rand" + "net" + "net/http" + "time" + "unsafe" +) + +// Bucket auto-updater gets the latest version of the bucket config from +// the server. If the configuration has changed then updated the local +// bucket information. 
If the bucket has been deleted then notify anyone +// who is holding a reference to this bucket + +const MAX_RETRY_COUNT = 5 +const DISCONNECT_PERIOD = 120 * time.Second + +type NotifyFn func(bucket string, err error) + +// Use TCP keepalive to detect half close sockets +var updaterTransport http.RoundTripper = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, +} + +var updaterHTTPClient = &http.Client{Transport: updaterTransport} + +func doHTTPRequestForUpdate(req *http.Request) (*http.Response, error) { + + var err error + var res *http.Response + + for i := 0; i < HTTP_MAX_RETRY; i++ { + res, err = updaterHTTPClient.Do(req) + if err != nil && isHttpConnError(err) { + continue + } + break + } + + if err != nil { + return nil, err + } + + return res, err +} + +func (b *Bucket) RunBucketUpdater(notify NotifyFn) { + go func() { + err := b.UpdateBucket() + if err != nil { + if notify != nil { + notify(b.GetName(), err) + } + logging.Errorf(" Bucket Updater exited with err %v", err) + } + }() +} + +func (b *Bucket) replaceConnPools2(with []*connectionPool, bucketLocked bool) { + if !bucketLocked { + b.Lock() + defer b.Unlock() + } + old := b.connPools + b.connPools = unsafe.Pointer(&with) + if old != nil { + for _, pool := range *(*[]*connectionPool)(old) { + if pool != nil && pool.inUse == false { + pool.Close() + } + } + } + return +} + +func (b *Bucket) UpdateBucket() error { + + var failures int + var returnErr error + + for { + + if failures == MAX_RETRY_COUNT { + logging.Errorf(" Maximum failures reached. Exiting loop...") + return fmt.Errorf("Max failures reached. Last Error %v", returnErr) + } + + nodes := b.Nodes() + if len(nodes) < 1 { + return fmt.Errorf("No healthy nodes found") + } + + startNode := rand.Intn(len(nodes)) + node := nodes[(startNode)%len(nodes)] + + streamUrl := fmt.Sprintf("http://%s/pools/default/bucketsStreaming/%s", node.Hostname, b.GetName()) + logging.Infof(" Trying with %s", streamUrl) + req, err := http.NewRequest("GET", streamUrl, nil) + if err != nil { + return err + } + + // Lock here to avoid having pool closed under us. + b.RLock() + err = maybeAddAuth(req, b.pool.client.ah) + b.RUnlock() + if err != nil { + return err + } + + res, err := doHTTPRequestForUpdate(req) + if err != nil { + return err + } + + if res.StatusCode != 200 { + bod, _ := ioutil.ReadAll(io.LimitReader(res.Body, 512)) + logging.Errorf("Failed to connect to host, unexpected status code: %v. Body %s", res.StatusCode, bod) + res.Body.Close() + returnErr = fmt.Errorf("Failed to connect to host. 
Status %v Body %s", res.StatusCode, bod) + failures++ + continue + } + + dec := json.NewDecoder(res.Body) + + tmpb := &Bucket{} + for { + + err := dec.Decode(&tmpb) + if err != nil { + returnErr = err + res.Body.Close() + break + } + + // if we got here, reset failure count + failures = 0 + b.Lock() + + // mark all the old connection pools for deletion + pools := b.getConnPools(true /* already locked */) + for _, pool := range pools { + if pool != nil { + pool.inUse = false + } + } + + newcps := make([]*connectionPool, len(tmpb.VBSMJson.ServerList)) + for i := range newcps { + // get the old connection pool and check if it is still valid + pool := b.getConnPoolByHost(tmpb.VBSMJson.ServerList[i], true /* bucket already locked */) + if pool != nil && pool.inUse == false { + // if the hostname and index is unchanged then reuse this pool + newcps[i] = pool + pool.inUse = true + continue + } + // else create a new pool + if b.ah != nil { + newcps[i] = newConnectionPool( + tmpb.VBSMJson.ServerList[i], + b.ah, false, PoolSize, PoolOverflow) + + } else { + newcps[i] = newConnectionPool( + tmpb.VBSMJson.ServerList[i], + b.authHandler(true /* bucket already locked */), + false, PoolSize, PoolOverflow) + } + } + + b.replaceConnPools2(newcps, true /* bucket already locked */) + + tmpb.ah = b.ah + b.vBucketServerMap = unsafe.Pointer(&tmpb.VBSMJson) + b.nodeList = unsafe.Pointer(&tmpb.NodesJSON) + b.Unlock() + + logging.Infof("Got new configuration for bucket %s", b.GetName()) + + } + // we are here because of an error + failures++ + continue + + } + return nil +} diff --git a/vendor/github.com/couchbase/go-couchbase/tap.go b/vendor/github.com/couchbase/go-couchbase/tap.go new file mode 100644 index 00000000..86edd305 --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/tap.go @@ -0,0 +1,143 @@ +package couchbase + +import ( + "github.com/couchbase/gomemcached/client" + "github.com/couchbase/goutils/logging" + "sync" + "time" +) + +const initialRetryInterval = 1 * time.Second +const maximumRetryInterval = 30 * time.Second + +// A TapFeed streams mutation events from a bucket. +// +// Events from the bucket can be read from the channel 'C'. Remember +// to call Close() on it when you're done, unless its channel has +// closed itself already. 
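+//
+// A consumption sketch (nil args selects the package defaults):
+//
+//	feed, err := bucket.StartTapFeed(nil)
+//	if err == nil {
+//		defer feed.Close()
+//		for event := range feed.C {
+//			fmt.Printf("tap: %s\n", event.Key)
+//		}
+//	}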
+type TapFeed struct { + C <-chan memcached.TapEvent + + bucket *Bucket + args *memcached.TapArguments + nodeFeeds []*memcached.TapFeed // The TAP feeds of the individual nodes + output chan memcached.TapEvent // Same as C but writeably-typed + wg sync.WaitGroup + quit chan bool +} + +// StartTapFeed creates and starts a new Tap feed +func (b *Bucket) StartTapFeed(args *memcached.TapArguments) (*TapFeed, error) { + if args == nil { + defaultArgs := memcached.DefaultTapArguments() + args = &defaultArgs + } + + feed := &TapFeed{ + bucket: b, + args: args, + output: make(chan memcached.TapEvent, 10), + quit: make(chan bool), + } + + go feed.run() + + feed.C = feed.output + return feed, nil +} + +// Goroutine that runs the feed +func (feed *TapFeed) run() { + retryInterval := initialRetryInterval + bucketOK := true + for { + // Connect to the TAP feed of each server node: + if bucketOK { + killSwitch, err := feed.connectToNodes() + if err == nil { + // Run until one of the sub-feeds fails: + select { + case <-killSwitch: + case <-feed.quit: + return + } + feed.closeNodeFeeds() + retryInterval = initialRetryInterval + } + } + + // On error, try to refresh the bucket in case the list of nodes changed: + logging.Infof("go-couchbase: TAP connection lost; reconnecting to bucket %q in %v", + feed.bucket.Name, retryInterval) + err := feed.bucket.Refresh() + bucketOK = err == nil + + select { + case <-time.After(retryInterval): + case <-feed.quit: + return + } + if retryInterval *= 2; retryInterval > maximumRetryInterval { + retryInterval = maximumRetryInterval + } + } +} + +func (feed *TapFeed) connectToNodes() (killSwitch chan bool, err error) { + killSwitch = make(chan bool) + for _, serverConn := range feed.bucket.getConnPools(false /* not already locked */) { + var singleFeed *memcached.TapFeed + singleFeed, err = serverConn.StartTapFeed(feed.args) + if err != nil { + logging.Errorf("go-couchbase: Error connecting to tap feed of %s: %v", serverConn.host, err) + feed.closeNodeFeeds() + return + } + feed.nodeFeeds = append(feed.nodeFeeds, singleFeed) + go feed.forwardTapEvents(singleFeed, killSwitch, serverConn.host) + feed.wg.Add(1) + } + return +} + +// Goroutine that forwards Tap events from a single node's feed to the aggregate feed. +func (feed *TapFeed) forwardTapEvents(singleFeed *memcached.TapFeed, killSwitch chan bool, host string) { + defer feed.wg.Done() + for { + select { + case event, ok := <-singleFeed.C: + if !ok { + if singleFeed.Error != nil { + logging.Errorf("go-couchbase: Tap feed from %s failed: %v", host, singleFeed.Error) + } + killSwitch <- true + return + } + feed.output <- event + case <-feed.quit: + return + } + } +} + +func (feed *TapFeed) closeNodeFeeds() { + for _, f := range feed.nodeFeeds { + f.Close() + } + feed.nodeFeeds = nil +} + +// Close a Tap feed. +func (feed *TapFeed) Close() error { + select { + case <-feed.quit: + return nil + default: + } + + feed.closeNodeFeeds() + close(feed.quit) + feed.wg.Wait() + close(feed.output) + return nil +} diff --git a/vendor/github.com/couchbase/go-couchbase/upr.go b/vendor/github.com/couchbase/go-couchbase/upr.go new file mode 100644 index 00000000..1c43ceeb --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/upr.go @@ -0,0 +1,398 @@ +package couchbase + +import ( + "log" + "sync" + "time" + + "fmt" + "github.com/couchbase/gomemcached" + "github.com/couchbase/gomemcached/client" + "github.com/couchbase/goutils/logging" +) + +// A UprFeed streams mutation events from a bucket. 
+// +// Events from the bucket can be read from the channel 'C'. Remember +// to call Close() on it when you're done, unless its channel has +// closed itself already. +type UprFeed struct { + C <-chan *memcached.UprEvent + + bucket *Bucket + nodeFeeds map[string]*FeedInfo // The UPR feeds of the individual nodes + output chan *memcached.UprEvent // Same as C but writeably-typed + outputClosed bool + quit chan bool + name string // name of this UPR feed + sequence uint32 // sequence number for this feed + connected bool + killSwitch chan bool + closing bool + wg sync.WaitGroup + dcp_buffer_size uint32 + data_chan_size int +} + +// UprFeed from a single connection +type FeedInfo struct { + uprFeed *memcached.UprFeed // UPR feed handle + host string // hostname + connected bool // connected + quit chan bool // quit channel +} + +type FailoverLog map[uint16]memcached.FailoverLog + +// GetFailoverLogs, get the failover logs for a set of vbucket ids +func (b *Bucket) GetFailoverLogs(vBuckets []uint16) (FailoverLog, error) { + + // map vbids to their corresponding hosts + vbHostList := make(map[string][]uint16) + vbm := b.VBServerMap() + if len(vbm.VBucketMap) < len(vBuckets) { + return nil, fmt.Errorf("vbmap smaller than vbucket list: %v vs. %v", + vbm.VBucketMap, vBuckets) + } + + for _, vb := range vBuckets { + masterID := vbm.VBucketMap[vb][0] + master := b.getMasterNode(masterID) + if master == "" { + return nil, fmt.Errorf("No master found for vb %d", vb) + } + + vbList := vbHostList[master] + if vbList == nil { + vbList = make([]uint16, 0) + } + vbList = append(vbList, vb) + vbHostList[master] = vbList + } + + failoverLogMap := make(FailoverLog) + for _, serverConn := range b.getConnPools(false /* not already locked */) { + + vbList := vbHostList[serverConn.host] + if vbList == nil { + continue + } + + mc, err := serverConn.Get() + if err != nil { + logging.Infof("No Free connections for vblist %v", vbList) + return nil, fmt.Errorf("No Free connections for host %s", + serverConn.host) + + } + // close the connection so that it doesn't get reused for upr data + // connection + defer mc.Close() + failoverlogs, err := mc.UprGetFailoverLog(vbList) + if err != nil { + return nil, fmt.Errorf("Error getting failover log %s host %s", + err.Error(), serverConn.host) + + } + + for vb, log := range failoverlogs { + failoverLogMap[vb] = *log + } + } + + return failoverLogMap, nil +} + +func (b *Bucket) StartUprFeed(name string, sequence uint32) (*UprFeed, error) { + return b.StartUprFeedWithConfig(name, sequence, 10, DEFAULT_WINDOW_SIZE) +} + +// StartUprFeed creates and starts a new Upr feed +// No data will be sent on the channel unless vbuckets streams are requested +func (b *Bucket) StartUprFeedWithConfig(name string, sequence uint32, data_chan_size int, dcp_buffer_size uint32) (*UprFeed, error) { + + feed := &UprFeed{ + bucket: b, + output: make(chan *memcached.UprEvent, data_chan_size), + quit: make(chan bool), + nodeFeeds: make(map[string]*FeedInfo, 0), + name: name, + sequence: sequence, + killSwitch: make(chan bool), + dcp_buffer_size: dcp_buffer_size, + data_chan_size: data_chan_size, + } + + err := feed.connectToNodes() + if err != nil { + return nil, fmt.Errorf("Cannot connect to bucket %s", err.Error()) + } + feed.connected = true + go feed.run() + + feed.C = feed.output + return feed, nil +} + +// UprRequestStream starts a stream for a vb on a feed +func (feed *UprFeed) UprRequestStream(vb uint16, opaque uint16, flags uint32, + vuuid, startSequence, endSequence, snapStart, snapEnd 
uint64) error { + + defer func() { + if r := recover(); r != nil { + log.Panic("Panic in UprRequestStream. Feed %v Bucket %v ", feed, feed.bucket) + } + }() + + vbm := feed.bucket.VBServerMap() + if len(vbm.VBucketMap) < int(vb) { + return fmt.Errorf("vbmap smaller than vbucket list: %v vs. %v", + vb, vbm.VBucketMap) + } + + if int(vb) >= len(vbm.VBucketMap) { + return fmt.Errorf("Invalid vbucket id %d", vb) + } + + masterID := vbm.VBucketMap[vb][0] + master := feed.bucket.getMasterNode(masterID) + if master == "" { + return fmt.Errorf("Master node not found for vbucket %d", vb) + } + singleFeed := feed.nodeFeeds[master] + if singleFeed == nil { + return fmt.Errorf("UprFeed for this host not found") + } + + if err := singleFeed.uprFeed.UprRequestStream(vb, opaque, flags, + vuuid, startSequence, endSequence, snapStart, snapEnd); err != nil { + return err + } + + return nil +} + +// UprCloseStream ends a vbucket stream. +func (feed *UprFeed) UprCloseStream(vb, opaqueMSB uint16) error { + + defer func() { + if r := recover(); r != nil { + log.Panic("Panic in UprCloseStream. Feed %v Bucket %v ", feed, feed.bucket) + } + }() + + vbm := feed.bucket.VBServerMap() + if len(vbm.VBucketMap) < int(vb) { + return fmt.Errorf("vbmap smaller than vbucket list: %v vs. %v", + vb, vbm.VBucketMap) + } + + if int(vb) >= len(vbm.VBucketMap) { + return fmt.Errorf("Invalid vbucket id %d", vb) + } + + masterID := vbm.VBucketMap[vb][0] + master := feed.bucket.getMasterNode(masterID) + if master == "" { + return fmt.Errorf("Master node not found for vbucket %d", vb) + } + singleFeed := feed.nodeFeeds[master] + if singleFeed == nil { + return fmt.Errorf("UprFeed for this host not found") + } + + if err := singleFeed.uprFeed.CloseStream(vb, opaqueMSB); err != nil { + return err + } + return nil +} + +// Goroutine that runs the feed +func (feed *UprFeed) run() { + retryInterval := initialRetryInterval + bucketOK := true + for { + // Connect to the UPR feed of each server node: + if bucketOK { + // Run until one of the sub-feeds fails: + select { + case <-feed.killSwitch: + case <-feed.quit: + return + } + //feed.closeNodeFeeds() + retryInterval = initialRetryInterval + } + + if feed.closing == true { + // we have been asked to shut down + return + } + + // On error, try to refresh the bucket in case the list of nodes changed: + logging.Infof("go-couchbase: UPR connection lost; reconnecting to bucket %q in %v", + feed.bucket.Name, retryInterval) + + if err := feed.bucket.Refresh(); err != nil { + // if we fail to refresh the bucket, exit the feed + // MB-14917 + logging.Infof("Unable to refresh bucket %s ", err.Error()) + close(feed.output) + feed.outputClosed = true + feed.closeNodeFeeds() + return + } + + // this will only connect to nodes that are not connected or changed + // user will have to reconnect the stream + err := feed.connectToNodes() + if err != nil { + logging.Infof("Unable to connect to nodes..exit ") + close(feed.output) + feed.outputClosed = true + feed.closeNodeFeeds() + return + } + bucketOK = err == nil + + select { + case <-time.After(retryInterval): + case <-feed.quit: + return + } + if retryInterval *= 2; retryInterval > maximumRetryInterval { + retryInterval = maximumRetryInterval + } + } +} + +func (feed *UprFeed) connectToNodes() (err error) { + nodeCount := 0 + for _, serverConn := range feed.bucket.getConnPools(false /* not already locked */) { + + // this maybe a reconnection, so check if the connection to the node + // already exists. 
Connect only if the node is not found in the list + // or connected == false + nodeFeed := feed.nodeFeeds[serverConn.host] + + if nodeFeed != nil && nodeFeed.connected == true { + continue + } + + var singleFeed *memcached.UprFeed + var name string + if feed.name == "" { + name = "DefaultUprClient" + } else { + name = feed.name + } + singleFeed, err = serverConn.StartUprFeed(name, feed.sequence, feed.dcp_buffer_size, feed.data_chan_size) + if err != nil { + logging.Errorf("go-couchbase: Error connecting to upr feed of %s: %v", serverConn.host, err) + feed.closeNodeFeeds() + return + } + // add the node to the connection map + feedInfo := &FeedInfo{ + uprFeed: singleFeed, + connected: true, + host: serverConn.host, + quit: make(chan bool), + } + feed.nodeFeeds[serverConn.host] = feedInfo + go feed.forwardUprEvents(feedInfo, feed.killSwitch, serverConn.host) + feed.wg.Add(1) + nodeCount++ + } + if nodeCount == 0 { + return fmt.Errorf("No connection to bucket") + } + + return nil +} + +// Goroutine that forwards Upr events from a single node's feed to the aggregate feed. +func (feed *UprFeed) forwardUprEvents(nodeFeed *FeedInfo, killSwitch chan bool, host string) { + singleFeed := nodeFeed.uprFeed + + defer func() { + feed.wg.Done() + if r := recover(); r != nil { + //if feed is not closing, re-throw the panic + if feed.outputClosed != true && feed.closing != true { + panic(r) + } else { + logging.Errorf("Panic is recovered. Since feed is closed, exit gracefully") + + } + } + }() + + for { + select { + case <-nodeFeed.quit: + nodeFeed.connected = false + return + + case event, ok := <-singleFeed.C: + if !ok { + if singleFeed.Error != nil { + logging.Errorf("go-couchbase: Upr feed from %s failed: %v", host, singleFeed.Error) + } + killSwitch <- true + return + } + if feed.outputClosed == true { + // someone closed the node feed + logging.Infof("Node need closed, returning from forwardUprEvent") + return + } + feed.output <- event + if event.Status == gomemcached.NOT_MY_VBUCKET { + logging.Infof(" Got a not my vbucket error !! ") + if err := feed.bucket.Refresh(); err != nil { + logging.Errorf("Unable to refresh bucket %s ", err.Error()) + feed.closeNodeFeeds() + return + } + // this will only connect to nodes that are not connected or changed + // user will have to reconnect the stream + if err := feed.connectToNodes(); err != nil { + logging.Errorf("Unable to connect to nodes %s", err.Error()) + return + } + + } + } + } +} + +func (feed *UprFeed) closeNodeFeeds() { + for _, f := range feed.nodeFeeds { + logging.Infof(" Sending close to forwardUprEvent ") + close(f.quit) + f.uprFeed.Close() + } + feed.nodeFeeds = nil +} + +// Close a Upr feed. 
+func (feed *UprFeed) Close() error { + select { + case <-feed.quit: + return nil + default: + } + + feed.closing = true + feed.closeNodeFeeds() + close(feed.quit) + + feed.wg.Wait() + if feed.outputClosed == false { + feed.outputClosed = true + close(feed.output) + } + + return nil +} diff --git a/vendor/github.com/couchbase/go-couchbase/users.go b/vendor/github.com/couchbase/go-couchbase/users.go new file mode 100644 index 00000000..47d48615 --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/users.go @@ -0,0 +1,119 @@ +package couchbase + +import ( + "bytes" + "fmt" +) + +type User struct { + Name string + Id string + Domain string + Roles []Role +} + +type Role struct { + Role string + BucketName string `json:"bucket_name"` +} + +// Sample: +// {"role":"admin","name":"Admin","desc":"Can manage ALL cluster features including security.","ce":true} +// {"role":"query_select","bucket_name":"*","name":"Query Select","desc":"Can execute SELECT statement on bucket to retrieve data"} +type RoleDescription struct { + Role string + Name string + Desc string + Ce bool + BucketName string `json:"bucket_name"` +} + +// Return user-role data, as parsed JSON. +// Sample: +// [{"id":"ivanivanov","name":"Ivan Ivanov","roles":[{"role":"cluster_admin"},{"bucket_name":"default","role":"bucket_admin"}]}, +// {"id":"petrpetrov","name":"Petr Petrov","roles":[{"role":"replication_admin"}]}] +func (c *Client) GetUserRoles() ([]interface{}, error) { + ret := make([]interface{}, 0, 1) + err := c.parseURLResponse("/settings/rbac/users", &ret) + if err != nil { + return nil, err + } + + // Get the configured administrator. + // Expected result: {"port":8091,"username":"Administrator"} + adminInfo := make(map[string]interface{}, 2) + err = c.parseURLResponse("/settings/web", &adminInfo) + if err != nil { + return nil, err + } + + // Create a special entry for the configured administrator. + adminResult := map[string]interface{}{ + "name": adminInfo["username"], + "id": adminInfo["username"], + "domain": "ns_server", + "roles": []interface{}{ + map[string]interface{}{ + "role": "admin", + }, + }, + } + + // Add the configured administrator to the list of results. + ret = append(ret, adminResult) + + return ret, nil +} + +func (c *Client) GetUserInfoAll() ([]User, error) { + ret := make([]User, 0, 16) + err := c.parseURLResponse("/settings/rbac/users", &ret) + if err != nil { + return nil, err + } + return ret, nil +} + +func rolesToParamFormat(roles []Role) string { + var buffer bytes.Buffer + for i, role := range roles { + if i > 0 { + buffer.WriteString(",") + } + buffer.WriteString(role.Role) + if role.BucketName != "" { + buffer.WriteString("[") + buffer.WriteString(role.BucketName) + buffer.WriteString("]") + } + } + return buffer.String() +} + +func (c *Client) PutUserInfo(u *User) error { + params := map[string]interface{}{ + "name": u.Name, + "roles": rolesToParamFormat(u.Roles), + } + var target string + switch u.Domain { + case "external": + target = "/settings/rbac/users/" + u.Id + case "local": + target = "/settings/rbac/users/local/" + u.Id + default: + return fmt.Errorf("Unknown user type: %s", u.Domain) + } + var ret string // PUT returns an empty string. We ignore it. 
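+	// The decoded response body lands in ret but is never inspected;
+	// only the error matters to the caller.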
+ err := c.parsePutURLResponse(target, params, &ret) + return err +} + +func (c *Client) GetRolesAll() ([]RoleDescription, error) { + ret := make([]RoleDescription, 0, 32) + err := c.parseURLResponse("/settings/rbac/roles", &ret) + if err != nil { + return nil, err + } + return ret, nil +} diff --git a/vendor/github.com/couchbase/go-couchbase/util.go b/vendor/github.com/couchbase/go-couchbase/util.go new file mode 100644 index 00000000..4d286a32 --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/util.go @@ -0,0 +1,49 @@ +package couchbase + +import ( + "fmt" + "net/url" + "strings" +) + +// CleanupHost returns the hostname with the given suffix removed. +func CleanupHost(h, commonSuffix string) string { + if strings.HasSuffix(h, commonSuffix) { + return h[:len(h)-len(commonSuffix)] + } + return h +} + +// FindCommonSuffix returns the longest common suffix from the given +// strings. +func FindCommonSuffix(input []string) string { + rv := "" + if len(input) < 2 { + return "" + } + from := input + for i := len(input[0]); i > 0; i-- { + common := true + suffix := input[0][i:] + for _, s := range from { + if !strings.HasSuffix(s, suffix) { + common = false + break + } + } + if common { + rv = suffix + } + } + return rv +} + +// ParseURL is a wrapper around url.Parse with some sanity-checking +func ParseURL(urlStr string) (result *url.URL, err error) { + result, err = url.Parse(urlStr) + if result != nil && result.Scheme == "" { + result = nil + err = fmt.Errorf("invalid URL <%s>", urlStr) + } + return +} diff --git a/vendor/github.com/couchbase/go-couchbase/vbmap.go b/vendor/github.com/couchbase/go-couchbase/vbmap.go new file mode 100644 index 00000000..b96a18ed --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/vbmap.go @@ -0,0 +1,77 @@ +package couchbase + +var crc32tab = []uint32{ + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, + 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, + 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, + 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, + 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, + 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, + 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, + 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, + 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, + 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, + 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, + 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, + 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, + 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, + 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, + 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, + 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, + 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, + 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, + 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, + 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, + 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, + 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, + 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, + 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, + 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, + 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, + 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, + 0xe3630b12, 
0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, + 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, + 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, + 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, + 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, + 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, + 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, + 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, + 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, + 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, + 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, + 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, + 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, + 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, + 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, + 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, + 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, + 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, + 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, + 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, + 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, + 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, + 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, + 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d} + +// VBHash finds the vbucket for the given key. +func (b *Bucket) VBHash(key string) uint32 { + crc := uint32(0xffffffff) + for x := 0; x < len(key); x++ { + crc = (crc >> 8) ^ crc32tab[(uint64(crc)^uint64(key[x]))&0xff] + } + vbm := b.VBServerMap() + return ((^crc) >> 16) & 0x7fff & (uint32(len(vbm.VBucketMap)) - 1) +} diff --git a/vendor/github.com/couchbase/go-couchbase/views.go b/vendor/github.com/couchbase/go-couchbase/views.go new file mode 100644 index 00000000..2f68642f --- /dev/null +++ b/vendor/github.com/couchbase/go-couchbase/views.go @@ -0,0 +1,231 @@ +package couchbase + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "time" +) + +// ViewRow represents a single result from a view. +// +// Doc is present only if include_docs was set on the request. +type ViewRow struct { + ID string + Key interface{} + Value interface{} + Doc *interface{} +} + +// A ViewError is a node-specific error indicating a partial failure +// within a view result. +type ViewError struct { + From string + Reason string +} + +func (ve ViewError) Error() string { + return "Node: " + ve.From + ", reason: " + ve.Reason +} + +// ViewResult holds the entire result set from a view request, +// including the rows and the errors. 
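+//
+// A consumption sketch (ddoc/view names are illustrative):
+//
+//	vres, err := bucket.View("beer", "brewery_beers", nil)
+//	if err == nil {
+//		for _, row := range vres.Rows {
+//			fmt.Println(row.ID, row.Key)
+//		}
+//	}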
+type ViewResult struct { + TotalRows int `json:"total_rows"` + Rows []ViewRow + Errors []ViewError +} + +func (b *Bucket) randomBaseURL() (*url.URL, error) { + nodes := b.HealthyNodes() + if len(nodes) == 0 { + return nil, errors.New("no available couch rest URLs") + } + nodeNo := rand.Intn(len(nodes)) + node := nodes[nodeNo] + + b.RLock() + name := b.Name + pool := b.pool + b.RUnlock() + + u, err := ParseURL(node.CouchAPIBase) + if err != nil { + return nil, fmt.Errorf("config error: Bucket %q node #%d CouchAPIBase=%q: %v", + name, nodeNo, node.CouchAPIBase, err) + } else if pool != nil { + u.User = pool.client.BaseURL.User + } + return u, err +} + +const START_NODE_ID = -1 + +func (b *Bucket) randomNextURL(lastNode int) (*url.URL, int, error) { + nodes := b.HealthyNodes() + if len(nodes) == 0 { + return nil, -1, errors.New("no available couch rest URLs") + } + + var nodeNo int + if lastNode == START_NODE_ID || lastNode >= len(nodes) { + // randomly select a node if the value of lastNode is invalid + nodeNo = rand.Intn(len(nodes)) + } else { + // wrap around the node list + nodeNo = (lastNode + 1) % len(nodes) + } + + b.RLock() + name := b.Name + pool := b.pool + b.RUnlock() + + node := nodes[nodeNo] + u, err := ParseURL(node.CouchAPIBase) + if err != nil { + return nil, -1, fmt.Errorf("config error: Bucket %q node #%d CouchAPIBase=%q: %v", + name, nodeNo, node.CouchAPIBase, err) + } else if pool != nil { + u.User = pool.client.BaseURL.User + } + return u, nodeNo, err +} + +// DocID is the document ID type for the startkey_docid parameter in +// views. +type DocID string + +func qParam(k, v string) string { + format := `"%s"` + switch k { + case "startkey_docid", "endkey_docid", "stale": + format = "%s" + } + return fmt.Sprintf(format, v) +} + +// ViewURL constructs a URL for a view with the given ddoc, view name, +// and parameters. +func (b *Bucket) ViewURL(ddoc, name string, + params map[string]interface{}) (string, error) { + u, err := b.randomBaseURL() + if err != nil { + return "", err + } + + values := url.Values{} + for k, v := range params { + switch t := v.(type) { + case DocID: + values[k] = []string{string(t)} + case string: + values[k] = []string{qParam(k, t)} + case int: + values[k] = []string{fmt.Sprintf(`%d`, t)} + case bool: + values[k] = []string{fmt.Sprintf(`%v`, t)} + default: + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("unsupported value-type %T in Query, "+ + "json encoder said %v", t, err) + } + values[k] = []string{fmt.Sprintf(`%v`, string(b))} + } + } + + if ddoc == "" && name == "_all_docs" { + u.Path = fmt.Sprintf("/%s/_all_docs", b.GetName()) + } else { + u.Path = fmt.Sprintf("/%s/_design/%s/_view/%s", b.GetName(), ddoc, name) + } + u.RawQuery = values.Encode() + + return u.String(), nil +} + +// ViewCallback is called for each view invocation. +var ViewCallback func(ddoc, name string, start time.Time, err error) + +// ViewCustom performs a view request that can map row values to a +// custom type. +// +// See the source to View for an example usage. 
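+//
+// A custom-typed sketch (the anonymous result type is illustrative):
+//
+//	var vres struct {
+//		TotalRows int `json:"total_rows"`
+//		Rows      []struct {
+//			ID  string      `json:"id"`
+//			Key interface{} `json:"key"`
+//		} `json:"rows"`
+//	}
+//	err := bucket.ViewCustom("beer", "brewery_beers", nil, &vres)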
+func (b *Bucket) ViewCustom(ddoc, name string, params map[string]interface{},
+	vres interface{}) (err error) {
+	if SlowServerCallWarningThreshold > 0 {
+		defer slowLog(time.Now(), "call to ViewCustom(%q, %q)", ddoc, name)
+	}
+
+	if ViewCallback != nil {
+		defer func(t time.Time) { ViewCallback(ddoc, name, t, err) }(time.Now())
+	}
+
+	u, err := b.ViewURL(ddoc, name, params)
+	if err != nil {
+		return err
+	}
+
+	req, err := http.NewRequest("GET", u, nil)
+	if err != nil {
+		return err
+	}
+
+	ah := b.authHandler(false /* bucket not yet locked */)
+	maybeAddAuth(req, ah)
+
+	res, err := doHTTPRequest(req)
+	if err != nil {
+		return fmt.Errorf("error starting view req at %v: %v", u, err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 200 {
+		bod := make([]byte, 512)
+		l, _ := res.Body.Read(bod)
+		return fmt.Errorf("error executing view req at %v: %v - %s",
+			u, res.Status, bod[:l])
+	}
+
+	body, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		return err
+	}
+	// surface decode failures instead of silently swallowing them
+	if err := json.Unmarshal(body, vres); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// View executes a view.
+//
+// The ddoc parameter is just the bare name of your design doc without
+// the "_design/" prefix.
+//
+// Parameters are string keys with values that correspond to couchbase
+// view parameters. Primitives should work fairly naturally (booleans,
+// ints, strings, etc...) and other values will be JSON marshaled
+// (useful for array indexing on view keys, for example).
+//
+// Example:
+//
+//	res, err := couchbase.View("myddoc", "myview", map[string]interface{}{
+//		"group_level":    2,
+//		"startkey_docid": []interface{}{"thing"},
+//		"endkey_docid":   []interface{}{"thing", map[string]string{}},
+//		"stale":          false,
+//	})
+func (b *Bucket) View(ddoc, name string, params map[string]interface{}) (ViewResult, error) {
+	vres := ViewResult{}
+
+	if err := b.ViewCustom(ddoc, name, params, &vres); err != nil {
+		// error in accessing views; retry once after a bucket refresh
+		b.Refresh()
+		return vres, b.ViewCustom(ddoc, name, params, &vres)
+	}
+	return vres, nil
+}
diff --git a/vendor/github.com/couchbase/gomemcached/.gitignore b/vendor/github.com/couchbase/gomemcached/.gitignore
new file mode 100644
index 00000000..f75d85a8
--- /dev/null
+++ b/vendor/github.com/couchbase/gomemcached/.gitignore
@@ -0,0 +1,6 @@
+#*
+*.[68]
+*~
+*.swp
+/gocache/gocache
+c.out
diff --git a/vendor/github.com/couchbase/gomemcached/LICENSE b/vendor/github.com/couchbase/gomemcached/LICENSE
new file mode 100644
index 00000000..b01ef802
--- /dev/null
+++ b/vendor/github.com/couchbase/gomemcached/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2013 Dustin Sallings
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/couchbase/gomemcached/README.markdown b/vendor/github.com/couchbase/gomemcached/README.markdown new file mode 100644 index 00000000..5e9b2de5 --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/README.markdown @@ -0,0 +1,32 @@ +# gomemcached + +This is a memcached binary protocol toolkit in [go][go]. + +It provides client and server functionality as well as a little sample +server showing how I might make a server if I valued purity over +performance. + +## Server Design + +
+*(overview diagram)*
+ +The basic design can be seen in [gocache]. A [storage +server][storage] is run as a goroutine that receives a `MCRequest` on +a channel, and then issues an `MCResponse` to a channel contained +within the request. + +Each connection is a separate goroutine, of course, and is responsible +for all IO for that connection until the connection drops or the +`dataServer` decides it's stupid and sends a fatal response back over +the channel. + +There is currently no work at all in making the thing perform (there +are specific areas I know need work). This is just my attempt to +learn the language somewhat. + +[go]: http://golang.org/ +[gocache]: gomemcached/blob/master/gocache/gocache.go +[storage]: gomemcached/blob/master/gocache/mc_storage.go diff --git a/vendor/github.com/couchbase/gomemcached/client/mc.go b/vendor/github.com/couchbase/gomemcached/client/mc.go new file mode 100644 index 00000000..435ba6fa --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/client/mc.go @@ -0,0 +1,1078 @@ +// Package memcached provides a memcached binary protocol client. +package memcached + +import ( + "encoding/binary" + "fmt" + "github.com/couchbase/gomemcached" + "github.com/couchbase/goutils/logging" + "github.com/couchbase/goutils/scramsha" + "github.com/pkg/errors" + "io" + "math" + "net" + "strings" + "sync" + "sync/atomic" + "time" +) + +type ClientIface interface { + Add(vb uint16, key string, flags int, exp int, body []byte) (*gomemcached.MCResponse, error) + Append(vb uint16, key string, data []byte) (*gomemcached.MCResponse, error) + Auth(user, pass string) (*gomemcached.MCResponse, error) + AuthList() (*gomemcached.MCResponse, error) + AuthPlain(user, pass string) (*gomemcached.MCResponse, error) + AuthScramSha(user, pass string) (*gomemcached.MCResponse, error) + CASNext(vb uint16, k string, exp int, state *CASState) bool + CAS(vb uint16, k string, f CasFunc, initexp int) (*gomemcached.MCResponse, error) + Close() error + Decr(vb uint16, key string, amt, def uint64, exp int) (uint64, error) + Del(vb uint16, key string) (*gomemcached.MCResponse, error) + EnableMutationToken() (*gomemcached.MCResponse, error) + Get(vb uint16, key string) (*gomemcached.MCResponse, error) + GetSubdoc(vb uint16, key string, subPaths []string) (*gomemcached.MCResponse, error) + GetAndTouch(vb uint16, key string, exp int) (*gomemcached.MCResponse, error) + GetBulk(vb uint16, keys []string, rv map[string]*gomemcached.MCResponse, subPaths []string) error + GetMeta(vb uint16, key string) (*gomemcached.MCResponse, error) + GetRandomDoc() (*gomemcached.MCResponse, error) + Hijack() io.ReadWriteCloser + Incr(vb uint16, key string, amt, def uint64, exp int) (uint64, error) + Observe(vb uint16, key string) (result ObserveResult, err error) + ObserveSeq(vb uint16, vbuuid uint64) (result *ObserveSeqResult, err error) + Receive() (*gomemcached.MCResponse, error) + ReceiveWithDeadline(deadline time.Time) (*gomemcached.MCResponse, error) + Send(req *gomemcached.MCRequest) (rv *gomemcached.MCResponse, err error) + Set(vb uint16, key string, flags int, exp int, body []byte) (*gomemcached.MCResponse, error) + SetKeepAliveOptions(interval time.Duration) + SetReadDeadline(t time.Time) + SetDeadline(t time.Time) + SelectBucket(bucket string) (*gomemcached.MCResponse, error) + SetCas(vb uint16, key string, flags int, exp int, cas uint64, body []byte) (*gomemcached.MCResponse, error) + Stats(key string) ([]StatValue, error) + StatsMap(key string) (map[string]string, error) + StatsMapForSpecifiedStats(key string, statsMap 
map[string]string) error
+	Transmit(req *gomemcached.MCRequest) error
+	TransmitWithDeadline(req *gomemcached.MCRequest, deadline time.Time) error
+	TransmitResponse(res *gomemcached.MCResponse) error
+
+	// UprFeed Related
+	NewUprFeed() (*UprFeed, error)
+	NewUprFeedIface() (UprFeedIface, error)
+	NewUprFeedWithConfig(ackByClient bool) (*UprFeed, error)
+	NewUprFeedWithConfigIface(ackByClient bool) (UprFeedIface, error)
+	UprGetFailoverLog(vb []uint16) (map[uint16]*FailoverLog, error)
+}
+
+const bufsize = 1024
+
+var UnHealthy uint32 = 0
+var Healthy uint32 = 1
+
+type Features []Feature
+type Feature uint16
+
+const FeatureMutationToken = Feature(0x04)
+const FeatureXattr = Feature(0x06)
+const FeatureDataType = Feature(0x0b)
+
+// The Client itself.
+type Client struct {
+	conn io.ReadWriteCloser
+	// use uint32 type so that it can be accessed through atomic APIs
+	healthy uint32
+	opaque  uint32
+
+	hdrBuf []byte
+}
+
+var (
+	DefaultDialTimeout = time.Duration(0) // No timeout
+
+	DefaultWriteTimeout = time.Duration(0) // No timeout
+
+	dialFun = func(prot, dest string) (net.Conn, error) {
+		return net.DialTimeout(prot, dest, DefaultDialTimeout)
+	}
+)
+
+// Connect to a memcached server.
+func Connect(prot, dest string) (rv *Client, err error) {
+	conn, err := dialFun(prot, dest)
+	if err != nil {
+		return nil, err
+	}
+	return Wrap(conn)
+}
+
+func SetDefaultTimeouts(dial, read, write time.Duration) {
+	DefaultDialTimeout = dial
+	DefaultWriteTimeout = write
+}
+
+func SetDefaultDialTimeout(dial time.Duration) {
+	DefaultDialTimeout = dial
+}
+
+func (c *Client) SetKeepAliveOptions(interval time.Duration) {
+	c.conn.(*net.TCPConn).SetKeepAlive(true)
+	c.conn.(*net.TCPConn).SetKeepAlivePeriod(interval)
+}
+
+func (c *Client) SetReadDeadline(t time.Time) {
+	c.conn.(*net.TCPConn).SetReadDeadline(t)
+}
+
+func (c *Client) SetDeadline(t time.Time) {
+	c.conn.(*net.TCPConn).SetDeadline(t)
+}
+
+// Wrap an existing transport.
+func Wrap(rwc io.ReadWriteCloser) (rv *Client, err error) {
+	client := &Client{
+		conn:   rwc,
+		hdrBuf: make([]byte, gomemcached.HDR_LEN),
+		opaque: uint32(1),
+	}
+	client.setHealthy(true)
+	return client, nil
+}
+
+// Close the connection when you're done.
+func (c *Client) Close() error {
+	return c.conn.Close()
+}
+
+// IsHealthy returns true unless the client is believed to have
+// difficulty communicating to its server.
+//
+// This is useful for connection pools where we want to
+// non-destructively determine that a connection may be reused.
+func (c Client) IsHealthy() bool {
+	healthyState := atomic.LoadUint32(&c.healthy)
+	return healthyState == Healthy
+}
+
+// Send a custom request and get the response.
+func (c *Client) Send(req *gomemcached.MCRequest) (rv *gomemcached.MCResponse, err error) {
+	err = c.Transmit(req)
+	if err != nil {
+		return
+	}
+	resp, _, err := getResponse(c.conn, c.hdrBuf)
+	c.setHealthy(!gomemcached.IsFatal(err))
+	return resp, err
+}
+
+// Transmit sends a request, but does not wait for a response.
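+//
+// A pipelining sketch (editor's illustration; request construction and
+// error handling are abbreviated):
+//
+//	for _, req := range reqs {
+//		if err := c.Transmit(req); err != nil {
+//			break
+//		}
+//	}
+//	res, err := c.Receive() // responses are then read back in order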
+func (c *Client) Transmit(req *gomemcached.MCRequest) error {
+	if DefaultWriteTimeout > 0 {
+		c.conn.(net.Conn).SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
+	}
+	_, err := transmitRequest(c.conn, req)
+	// clear write deadline to avoid interference with future write operations
+	if DefaultWriteTimeout > 0 {
+		c.conn.(net.Conn).SetWriteDeadline(time.Time{})
+	}
+	if err != nil {
+		c.setHealthy(false)
+	}
+	return err
+}
+
+func (c *Client) TransmitWithDeadline(req *gomemcached.MCRequest, deadline time.Time) error {
+	c.conn.(net.Conn).SetWriteDeadline(deadline)
+
+	_, err := transmitRequest(c.conn, req)
+
+	// clear write deadline to avoid interference with future write operations
+	c.conn.(net.Conn).SetWriteDeadline(time.Time{})
+
+	if err != nil {
+		c.setHealthy(false)
+	}
+	return err
+}
+
+// TransmitResponse sends a response; it does not wait.
+func (c *Client) TransmitResponse(res *gomemcached.MCResponse) error {
+	if DefaultWriteTimeout > 0 {
+		c.conn.(net.Conn).SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
+	}
+	_, err := transmitResponse(c.conn, res)
+	// clear write deadline to avoid interference with future write operations
+	if DefaultWriteTimeout > 0 {
+		c.conn.(net.Conn).SetWriteDeadline(time.Time{})
+	}
+	if err != nil {
+		c.setHealthy(false)
+	}
+	return err
+}
+
+// Receive a response
+func (c *Client) Receive() (*gomemcached.MCResponse, error) {
+	resp, _, err := getResponse(c.conn, c.hdrBuf)
+	if err != nil && resp.Status != gomemcached.KEY_ENOENT && resp.Status != gomemcached.EBUSY {
+		c.setHealthy(false)
+	}
+	return resp, err
+}
+
+func (c *Client) ReceiveWithDeadline(deadline time.Time) (*gomemcached.MCResponse, error) {
+	c.conn.(net.Conn).SetReadDeadline(deadline)
+
+	resp, _, err := getResponse(c.conn, c.hdrBuf)
+
+	// Clear read deadline to avoid interference with future read operations.
+	c.conn.(net.Conn).SetReadDeadline(time.Time{})
+
+	if err != nil && resp.Status != gomemcached.KEY_ENOENT && resp.Status != gomemcached.EBUSY {
+		c.setHealthy(false)
+	}
+	return resp, err
+}
+
+func appendMutationToken(bytes []byte) []byte {
+	bytes = append(bytes, 0, 0)
+	binary.BigEndian.PutUint16(bytes[len(bytes)-2:], uint16(0x04))
+	return bytes
+}
+
+// Send a hello command to enable MutationTokens
+func (c *Client) EnableMutationToken() (*gomemcached.MCResponse, error) {
+	var payload []byte
+	payload = appendMutationToken(payload)
+
+	return c.Send(&gomemcached.MCRequest{
+		Opcode: gomemcached.HELLO,
+		Key:    []byte("GoMemcached"),
+		Body:   payload,
+	})
+}
+
+// Send a hello command to enable specific features
+func (c *Client) EnableFeatures(features Features) (*gomemcached.MCResponse, error) {
+	var payload []byte
+
+	for _, feature := range features {
+		payload = append(payload, 0, 0)
+		binary.BigEndian.PutUint16(payload[len(payload)-2:], uint16(feature))
+	}
+
+	return c.Send(&gomemcached.MCRequest{
+		Opcode: gomemcached.HELLO,
+		Key:    []byte("GoMemcached"),
+		Body:   payload,
+	})
+}
+
+// Get the value for a key.
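+//
+// Usage sketch (editor's addition; the vbucket and key are placeholders):
+//
+//	res, err := c.Get(0, "user::1234")
+//	if err == nil {
+//		fmt.Printf("cas=%d value=%s\n", res.Cas, res.Body)
+//	}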
+func (c *Client) Get(vb uint16, key string) (*gomemcached.MCResponse, error) { + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.GET, + VBucket: vb, + Key: []byte(key), + }) +} + +// Get the xattrs, doc value for the input key +func (c *Client) GetSubdoc(vb uint16, key string, subPaths []string) (*gomemcached.MCResponse, error) { + + extraBuf, valueBuf := GetSubDocVal(subPaths) + res, err := c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.SUBDOC_MULTI_LOOKUP, + VBucket: vb, + Key: []byte(key), + Extras: extraBuf, + Body: valueBuf, + }) + + if err != nil && IfResStatusError(res) { + return res, err + } + return res, nil +} + +// Get the value for a key, and update expiry +func (c *Client) GetAndTouch(vb uint16, key string, exp int) (*gomemcached.MCResponse, error) { + extraBuf := make([]byte, 4) + binary.BigEndian.PutUint32(extraBuf[0:], uint32(exp)) + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.GAT, + VBucket: vb, + Key: []byte(key), + Extras: extraBuf, + }) +} + +// Get metadata for a key +func (c *Client) GetMeta(vb uint16, key string) (*gomemcached.MCResponse, error) { + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.GET_META, + VBucket: vb, + Key: []byte(key), + }) +} + +// Del deletes a key. +func (c *Client) Del(vb uint16, key string) (*gomemcached.MCResponse, error) { + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.DELETE, + VBucket: vb, + Key: []byte(key)}) +} + +// Get a random document +func (c *Client) GetRandomDoc() (*gomemcached.MCResponse, error) { + return c.Send(&gomemcached.MCRequest{ + Opcode: 0xB6, + }) +} + +// AuthList lists SASL auth mechanisms. +func (c *Client) AuthList() (*gomemcached.MCResponse, error) { + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.SASL_LIST_MECHS}) +} + +// Auth performs SASL PLAIN authentication against the server. +func (c *Client) Auth(user, pass string) (*gomemcached.MCResponse, error) { + res, err := c.AuthList() + + if err != nil { + return res, err + } + + authMech := string(res.Body) + if strings.Index(authMech, "PLAIN") != -1 { + return c.AuthPlain(user, pass) + } + return nil, fmt.Errorf("auth mechanism PLAIN not supported") +} + +// AuthScramSha performs SCRAM-SHA authentication against the server. 
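+//
+// Typical call site (editor's sketch; credentials are placeholders), falling
+// back to PLAIN when SCRAM negotiation fails:
+//
+//	if _, err := c.AuthScramSha("user", "secret"); err != nil {
+//		_, err = c.AuthPlain("user", "secret")
+//	}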
+func (c *Client) AuthScramSha(user, pass string) (*gomemcached.MCResponse, error) { + res, err := c.AuthList() + if err != nil { + return nil, errors.Wrap(err, "Unable to obtain list of methods.") + } + + methods := string(res.Body) + method, err := scramsha.BestMethod(methods) + if err != nil { + return nil, errors.Wrap(err, + "Unable to select SCRAM-SHA method.") + } + + s, err := scramsha.NewScramSha(method) + if err != nil { + return nil, errors.Wrap(err, "Unable to initialize scramsha.") + } + + logging.Infof("Using %v authentication for user %v%v%v", method, gomemcached.UdTagBegin, user, gomemcached.UdTagEnd) + + message, err := s.GetStartRequest(user) + if err != nil { + return nil, errors.Wrapf(err, + "Error building start request for user %s.", user) + } + + startRequest := &gomemcached.MCRequest{ + Opcode: 0x21, + Key: []byte(method), + Body: []byte(message)} + + startResponse, err := c.Send(startRequest) + if err != nil { + return nil, errors.Wrap(err, "Error sending start request.") + } + + err = s.HandleStartResponse(string(startResponse.Body)) + if err != nil { + return nil, errors.Wrap(err, "Error handling start response.") + } + + message = s.GetFinalRequest(pass) + + // send step request + finalRequest := &gomemcached.MCRequest{ + Opcode: 0x22, + Key: []byte(method), + Body: []byte(message)} + finalResponse, err := c.Send(finalRequest) + if err != nil { + return nil, errors.Wrap(err, "Error sending final request.") + } + + err = s.HandleFinalResponse(string(finalResponse.Body)) + if err != nil { + return nil, errors.Wrap(err, "Error handling final response.") + } + + return finalResponse, nil +} + +func (c *Client) AuthPlain(user, pass string) (*gomemcached.MCResponse, error) { + logging.Infof("Using plain authentication for user %v%v%v", gomemcached.UdTagBegin, user, gomemcached.UdTagEnd) + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.SASL_AUTH, + Key: []byte("PLAIN"), + Body: []byte(fmt.Sprintf("\x00%s\x00%s", user, pass))}) +} + +// select bucket +func (c *Client) SelectBucket(bucket string) (*gomemcached.MCResponse, error) { + + return c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.SELECT_BUCKET, + Key: []byte(fmt.Sprintf("%s", bucket))}) +} + +func (c *Client) store(opcode gomemcached.CommandCode, vb uint16, + key string, flags int, exp int, body []byte) (*gomemcached.MCResponse, error) { + + req := &gomemcached.MCRequest{ + Opcode: opcode, + VBucket: vb, + Key: []byte(key), + Cas: 0, + Opaque: 0, + Extras: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + Body: body} + + binary.BigEndian.PutUint64(req.Extras, uint64(flags)<<32|uint64(exp)) + return c.Send(req) +} + +func (c *Client) storeCas(opcode gomemcached.CommandCode, vb uint16, + key string, flags int, exp int, cas uint64, body []byte) (*gomemcached.MCResponse, error) { + + req := &gomemcached.MCRequest{ + Opcode: opcode, + VBucket: vb, + Key: []byte(key), + Cas: cas, + Opaque: 0, + Extras: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + Body: body} + + binary.BigEndian.PutUint64(req.Extras, uint64(flags)<<32|uint64(exp)) + return c.Send(req) +} + +// Incr increments the value at the given key. 
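+//
+// Parameter semantics, sketched (editor's note): amt is the delta, def is the
+// initial value stored when the key is absent, exp is the expiry in seconds.
+//
+//	n, err := c.Incr(0, "hits", 1, 100, 0)
+//	// first call stores and returns 100; later calls return 101, 102, ...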
+func (c *Client) Incr(vb uint16, key string, + amt, def uint64, exp int) (uint64, error) { + + req := &gomemcached.MCRequest{ + Opcode: gomemcached.INCREMENT, + VBucket: vb, + Key: []byte(key), + Extras: make([]byte, 8+8+4), + } + binary.BigEndian.PutUint64(req.Extras[:8], amt) + binary.BigEndian.PutUint64(req.Extras[8:16], def) + binary.BigEndian.PutUint32(req.Extras[16:20], uint32(exp)) + + resp, err := c.Send(req) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(resp.Body), nil +} + +// Decr decrements the value at the given key. +func (c *Client) Decr(vb uint16, key string, + amt, def uint64, exp int) (uint64, error) { + + req := &gomemcached.MCRequest{ + Opcode: gomemcached.DECREMENT, + VBucket: vb, + Key: []byte(key), + Extras: make([]byte, 8+8+4), + } + binary.BigEndian.PutUint64(req.Extras[:8], amt) + binary.BigEndian.PutUint64(req.Extras[8:16], def) + binary.BigEndian.PutUint32(req.Extras[16:20], uint32(exp)) + + resp, err := c.Send(req) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(resp.Body), nil +} + +// Add a value for a key (store if not exists). +func (c *Client) Add(vb uint16, key string, flags int, exp int, + body []byte) (*gomemcached.MCResponse, error) { + return c.store(gomemcached.ADD, vb, key, flags, exp, body) +} + +// Set the value for a key. +func (c *Client) Set(vb uint16, key string, flags int, exp int, + body []byte) (*gomemcached.MCResponse, error) { + return c.store(gomemcached.SET, vb, key, flags, exp, body) +} + +// SetCas set the value for a key with cas +func (c *Client) SetCas(vb uint16, key string, flags int, exp int, cas uint64, + body []byte) (*gomemcached.MCResponse, error) { + return c.storeCas(gomemcached.SET, vb, key, flags, exp, cas, body) +} + +// Append data to the value of a key. +func (c *Client) Append(vb uint16, key string, data []byte) (*gomemcached.MCResponse, error) { + req := &gomemcached.MCRequest{ + Opcode: gomemcached.APPEND, + VBucket: vb, + Key: []byte(key), + Cas: 0, + Opaque: 0, + Body: data} + + return c.Send(req) +} + +// GetBulk gets keys in bulk +func (c *Client) GetBulk(vb uint16, keys []string, rv map[string]*gomemcached.MCResponse, subPaths []string) error { + stopch := make(chan bool) + var wg sync.WaitGroup + + defer func() { + close(stopch) + wg.Wait() + }() + + if (math.MaxInt32 - c.opaque) < (uint32(len(keys)) + 1) { + c.opaque = uint32(1) + } + + opStart := c.opaque + + errch := make(chan error, 2) + + wg.Add(1) + go func() { + defer func() { + if r := recover(); r != nil { + logging.Infof("Recovered in f %v", r) + } + errch <- nil + wg.Done() + }() + + ok := true + for ok { + + select { + case <-stopch: + return + default: + res, err := c.Receive() + + if err != nil && IfResStatusError(res) { + if res == nil || res.Status != gomemcached.KEY_ENOENT { + errch <- err + return + } + // continue receiving in case of KEY_ENOENT + } else if res.Opcode == gomemcached.GET { + opaque := res.Opaque - opStart + if opaque < 0 || opaque >= uint32(len(keys)) { + // Every now and then we seem to be seeing an invalid opaque + // value returned from the server. When this happens log the error + // and the calling function will retry the bulkGet. MB-15140 + logging.Errorf(" Invalid opaque Value. 
Debug info : Res.opaque : %v(%v), Keys %v, Response received %v \n key list %v this key %v", res.Opaque, opaque, len(keys), res, keys, string(res.Body)) + errch <- fmt.Errorf("Out of Bounds error") + return + } + + rv[keys[opaque]] = res + } + + if res.Opcode == gomemcached.NOOP { + ok = false + } + + if res.Opcode == gomemcached.SUBDOC_GET || res.Opcode == gomemcached.SUBDOC_MULTI_LOOKUP { + opaque := res.Opaque - opStart + rv[keys[opaque]] = res + } + + } + } + }() + + memcachedReqPkt := &gomemcached.MCRequest{ + Opcode: gomemcached.GET, + VBucket: vb, + } + + if len(subPaths) > 0 { + extraBuf, valueBuf := GetSubDocVal(subPaths) + memcachedReqPkt.Opcode = gomemcached.SUBDOC_MULTI_LOOKUP + memcachedReqPkt.Extras = extraBuf + memcachedReqPkt.Body = valueBuf + } + + for _, k := range keys { // Start of Get request + memcachedReqPkt.Key = []byte(k) + memcachedReqPkt.Opaque = c.opaque + + err := c.Transmit(memcachedReqPkt) + if err != nil { + logging.Errorf(" Transmit failed in GetBulkAll %v", err) + return err + } + c.opaque++ + } // End of Get request + + // finally transmit a NOOP + err := c.Transmit(&gomemcached.MCRequest{ + Opcode: gomemcached.NOOP, + VBucket: vb, + Opaque: c.opaque, + }) + + if err != nil { + logging.Errorf(" Transmit of NOOP failed %v", err) + return err + } + c.opaque++ + + return <-errch +} + +func GetSubDocVal(subPaths []string) (extraBuf, valueBuf []byte) { + + var ops []string + totalBytesLen := 0 + num := 1 + + for _, v := range subPaths { + totalBytesLen = totalBytesLen + len([]byte(v)) + ops = append(ops, v) + num = num + 1 + } + + // Xattr retrieval - subdoc multi get + extraBuf = append(extraBuf, uint8(0x04)) + + valueBuf = make([]byte, num*4+totalBytesLen) + + //opcode for subdoc get + op := gomemcached.SUBDOC_GET + + // Calculate path total bytes + // There are 2 ops - get xattrs - both input and $document and get whole doc + valIter := 0 + + for _, v := range ops { + pathBytes := []byte(v) + valueBuf[valIter+0] = uint8(op) + + // SubdocFlagXattrPath indicates that the path refers to + // an Xattr rather than the document body. + valueBuf[valIter+1] = uint8(gomemcached.SUBDOC_FLAG_XATTR) + + // 2 byte key + binary.BigEndian.PutUint16(valueBuf[valIter+2:], uint16(len(pathBytes))) + + // Then n bytes path + copy(valueBuf[valIter+4:], pathBytes) + valIter = valIter + 4 + len(pathBytes) + } + + return +} + +// ObservedStatus is the type reported by the Observe method +type ObservedStatus uint8 + +// Observation status values. 
+const ( + ObservedNotPersisted = ObservedStatus(0x00) // found, not persisted + ObservedPersisted = ObservedStatus(0x01) // found, persisted + ObservedNotFound = ObservedStatus(0x80) // not found (or a persisted delete) + ObservedLogicallyDeleted = ObservedStatus(0x81) // pending deletion (not persisted yet) +) + +// ObserveResult represents the data obtained by an Observe call +type ObserveResult struct { + Status ObservedStatus // Whether the value has been persisted/deleted + Cas uint64 // Current value's CAS + PersistenceTime time.Duration // Node's average time to persist a value + ReplicationTime time.Duration // Node's average time to replicate a value +} + +// Observe gets the persistence/replication/CAS state of a key +func (c *Client) Observe(vb uint16, key string) (result ObserveResult, err error) { + // http://www.couchbase.com/wiki/display/couchbase/Observe + body := make([]byte, 4+len(key)) + binary.BigEndian.PutUint16(body[0:2], vb) + binary.BigEndian.PutUint16(body[2:4], uint16(len(key))) + copy(body[4:4+len(key)], key) + + res, err := c.Send(&gomemcached.MCRequest{ + Opcode: gomemcached.OBSERVE, + VBucket: vb, + Body: body, + }) + if err != nil { + return + } + + // Parse the response data from the body: + if len(res.Body) < 2+2+1 { + err = io.ErrUnexpectedEOF + return + } + outVb := binary.BigEndian.Uint16(res.Body[0:2]) + keyLen := binary.BigEndian.Uint16(res.Body[2:4]) + if len(res.Body) < 2+2+int(keyLen)+1+8 { + err = io.ErrUnexpectedEOF + return + } + outKey := string(res.Body[4 : 4+keyLen]) + if outVb != vb || outKey != key { + err = fmt.Errorf("observe returned wrong vbucket/key: %d/%q", outVb, outKey) + return + } + result.Status = ObservedStatus(res.Body[4+keyLen]) + result.Cas = binary.BigEndian.Uint64(res.Body[5+keyLen:]) + // The response reuses the Cas field to store time statistics: + result.PersistenceTime = time.Duration(res.Cas>>32) * time.Millisecond + result.ReplicationTime = time.Duration(res.Cas&math.MaxUint32) * time.Millisecond + return +} + +// CheckPersistence checks whether a stored value has been persisted to disk yet. 
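+//
+// Sketch of an observe-then-check poll (editor's illustration; writeCas is
+// the CAS returned by the write being tracked):
+//
+//	obs, err := c.Observe(vb, key)
+//	if err == nil {
+//		persisted, overwritten := obs.CheckPersistence(writeCas, false)
+//		_, _ = persisted, overwritten // poll again until one becomes true
+//	}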
+func (result ObserveResult) CheckPersistence(cas uint64, deletion bool) (persisted bool, overwritten bool) {
+	switch {
+	case result.Status == ObservedNotFound && deletion:
+		persisted = true
+	case result.Cas != cas:
+		overwritten = true
+	case result.Status == ObservedPersisted:
+		persisted = true
+	}
+	return
+}
+
+// Sequence number based Observe implementation
+type ObserveSeqResult struct {
+	Failover           uint8  // Set to 1 if a failover took place
+	VbId               uint16 // vbucket id
+	Vbuuid             uint64 // vbucket uuid
+	LastPersistedSeqNo uint64 // last persisted sequence number
+	CurrentSeqNo       uint64 // current sequence number
+	OldVbuuid          uint64 // Old bucket vbuuid
+	LastSeqNo          uint64 // last sequence number received before failover
+}
+
+func (c *Client) ObserveSeq(vb uint16, vbuuid uint64) (result *ObserveSeqResult, err error) {
+	// http://www.couchbase.com/wiki/display/couchbase/Observe
+	body := make([]byte, 8)
+	binary.BigEndian.PutUint64(body[0:8], vbuuid)
+
+	res, err := c.Send(&gomemcached.MCRequest{
+		Opcode:  gomemcached.OBSERVE_SEQNO,
+		VBucket: vb,
+		Body:    body,
+		Opaque:  0x01,
+	})
+	if err != nil {
+		return
+	}
+
+	if res.Status != gomemcached.SUCCESS {
+		return nil, fmt.Errorf(" Observe returned error %v", res.Status)
+	}
+
+	// Parse the response data from the body:
+	if len(res.Body) < (1 + 2 + 8 + 8 + 8) {
+		err = io.ErrUnexpectedEOF
+		return
+	}
+
+	result = &ObserveSeqResult{}
+	result.Failover = res.Body[0]
+	result.VbId = binary.BigEndian.Uint16(res.Body[1:3])
+	result.Vbuuid = binary.BigEndian.Uint64(res.Body[3:11])
+	result.LastPersistedSeqNo = binary.BigEndian.Uint64(res.Body[11:19])
+	result.CurrentSeqNo = binary.BigEndian.Uint64(res.Body[19:27])
+
+	// in case of failover processing we can have old vbuuid and the last persisted seq number
+	if result.Failover == 1 && len(res.Body) >= (1+2+8+8+8+8+8) {
+		result.OldVbuuid = binary.BigEndian.Uint64(res.Body[27:35])
+		result.LastSeqNo = binary.BigEndian.Uint64(res.Body[35:43])
+	}
+
+	return
+}
+
+// CasOp is the type of operation to perform on this CAS loop.
+type CasOp uint8
+
+const (
+	// CASStore instructs the server to store the new value normally
+	CASStore = CasOp(iota)
+	// CASQuit instructs the client to stop attempting to CAS, leaving value untouched
+	CASQuit
+	// CASDelete instructs the server to delete the current value
+	CASDelete
+)
+
+// User specified termination is returned as an error.
+func (c CasOp) Error() string {
+	switch c {
+	case CASStore:
+		return "CAS store"
+	case CASQuit:
+		return "CAS quit"
+	case CASDelete:
+		return "CAS delete"
+	}
+	panic("Unhandled value")
+}
+
+//////// CAS TRANSFORM
+
+// CASState tracks the state of CAS over several operations.
+//
+// This is used directly by CASNext and indirectly by CAS
+type CASState struct {
+	initialized bool   // false on the first call to CASNext, then true
+	Value       []byte // Current value of key; update in place to new value
+	Cas         uint64 // Current CAS value of key
+	Exists      bool   // Does a value exist for the key? (If not, Value will be nil)
+	Err         error  // Error, if any, after CASNext returns false
+	resp        *gomemcached.MCResponse
+}
+
+// CASNext is a non-callback, loop-based version of CAS method.
+//
+// Usage is like this:
+//
+//	var state memcached.CASState
+//	for client.CASNext(vb, key, exp, &state) {
+//		state.Value = some_mutation(state.Value)
+//	}
+//	if state.Err != nil { ... }
+func (c *Client) CASNext(vb uint16, k string, exp int, state *CASState) bool {
+	if state.initialized {
+		if !state.Exists {
+			// Adding a new key:
+			if state.Value == nil {
+				state.Cas = 0
+				return false // no-op (delete of non-existent value)
+			}
+			state.resp, state.Err = c.Add(vb, k, 0, exp, state.Value)
+		} else {
+			// Updating / deleting a key:
+			req := &gomemcached.MCRequest{
+				Opcode:  gomemcached.DELETE,
+				VBucket: vb,
+				Key:     []byte(k),
+				Cas:     state.Cas}
+			if state.Value != nil {
+				req.Opcode = gomemcached.SET
+				req.Opaque = 0
+				req.Extras = []byte{0, 0, 0, 0, 0, 0, 0, 0}
+				req.Body = state.Value
+
+				flags := 0
+				binary.BigEndian.PutUint64(req.Extras, uint64(flags)<<32|uint64(exp))
+			}
+			state.resp, state.Err = c.Send(req)
+		}
+
+		// If the response status is KEY_EEXISTS or NOT_STORED there's a conflict and we'll need to
+		// get the new value (below). Otherwise, we're done (either success or failure) so return:
+		if !(state.resp != nil && (state.resp.Status == gomemcached.KEY_EEXISTS ||
+			state.resp.Status == gomemcached.NOT_STORED)) {
+			state.Cas = state.resp.Cas
+			return false // either success or fatal error
+		}
+	}
+
+	// Initial call, or after a conflict: GET the current value and CAS and return them:
+	state.initialized = true
+	if state.resp, state.Err = c.Get(vb, k); state.Err == nil {
+		state.Exists = true
+		state.Value = state.resp.Body
+		state.Cas = state.resp.Cas
+	} else if state.resp != nil && state.resp.Status == gomemcached.KEY_ENOENT {
+		state.Err = nil
+		state.Exists = false
+		state.Value = nil
+		state.Cas = 0
+	} else {
+		return false // fatal error
+	}
+	return true // keep going...
+}
+
+// CasFunc is the type of function used to perform a CAS transform.
+//
+// Input is the current value, or nil if no value exists.
+// The function should return the new value (if any) to set, and the store/quit/delete operation.
+type CasFunc func(current []byte) ([]byte, CasOp)
+
+// CAS performs a CAS transform with the given function.
+//
+// If the value does not exist, a nil current value will be sent to f.
+func (c *Client) CAS(vb uint16, k string, f CasFunc,
+	initexp int) (*gomemcached.MCResponse, error) {
+	var state CASState
+	for c.CASNext(vb, k, initexp, &state) {
+		newValue, operation := f(state.Value)
+		if operation == CASQuit || (operation == CASDelete && state.Value == nil) {
+			return nil, operation
+		}
+		state.Value = newValue
+	}
+	return state.resp, state.Err
+}
+
+// StatValue is one of the stats returned from the Stats method.
+type StatValue struct {
+	// The stat key
+	Key string
+	// The stat value
+	Val string
+}
+
+// Stats requests server-side stats.
+//
+// Use "" as the stat key for toplevel stats.
+func (c *Client) Stats(key string) ([]StatValue, error) {
+	rv := make([]StatValue, 0, 128)
+
+	req := &gomemcached.MCRequest{
+		Opcode: gomemcached.STAT,
+		Key:    []byte(key),
+		Opaque: 918494,
+	}
+
+	err := c.Transmit(req)
+	if err != nil {
+		return rv, err
+	}
+
+	for {
+		res, _, err := getResponse(c.conn, c.hdrBuf)
+		if err != nil {
+			return rv, err
+		}
+		k := string(res.Key)
+		if k == "" {
+			break
+		}
+		rv = append(rv, StatValue{
+			Key: k,
+			Val: string(res.Body),
+		})
+	}
+	return rv, nil
+}
+
+// StatsMap requests server-side stats similarly to Stats, but returns
+// them as a map.
+//
+// Use "" as the stat key for toplevel stats.
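+//
+// Usage sketch (editor's addition; the stat name shown is only an example):
+//
+//	m, err := c.StatsMap("") // "" selects toplevel stats
+//	if err == nil {
+//		fmt.Println(m["curr_connections"])
+//	}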
+func (c *Client) StatsMap(key string) (map[string]string, error) { + rv := make(map[string]string) + + req := &gomemcached.MCRequest{ + Opcode: gomemcached.STAT, + Key: []byte(key), + Opaque: 918494, + } + + err := c.Transmit(req) + if err != nil { + return rv, err + } + + for { + res, _, err := getResponse(c.conn, c.hdrBuf) + if err != nil { + return rv, err + } + k := string(res.Key) + if k == "" { + break + } + rv[k] = string(res.Body) + } + + return rv, nil +} + +// instead of returning a new statsMap, simply populate passed in statsMap, which contains all the keys +// for which stats needs to be retrieved +func (c *Client) StatsMapForSpecifiedStats(key string, statsMap map[string]string) error { + + // clear statsMap + for key, _ := range statsMap { + statsMap[key] = "" + } + + req := &gomemcached.MCRequest{ + Opcode: gomemcached.STAT, + Key: []byte(key), + Opaque: 918494, + } + + err := c.Transmit(req) + if err != nil { + return err + } + + for { + res, _, err := getResponse(c.conn, c.hdrBuf) + if err != nil { + return err + } + k := string(res.Key) + if k == "" { + break + } + if _, ok := statsMap[k]; ok { + statsMap[k] = string(res.Body) + } + } + + return nil +} + +// Hijack exposes the underlying connection from this client. +// +// It also marks the connection as unhealthy since the client will +// have lost control over the connection and can't otherwise verify +// things are in good shape for connection pools. +func (c *Client) Hijack() io.ReadWriteCloser { + c.setHealthy(false) + return c.conn +} + +func (c *Client) setHealthy(healthy bool) { + healthyState := UnHealthy + if healthy { + healthyState = Healthy + } + atomic.StoreUint32(&c.healthy, healthyState) +} + +func IfResStatusError(response *gomemcached.MCResponse) bool { + return response == nil || + (response.Status != gomemcached.SUBDOC_BAD_MULTI && + response.Status != gomemcached.SUBDOC_PATH_NOT_FOUND && + response.Status != gomemcached.SUBDOC_MULTI_PATH_FAILURE_DELETED) +} diff --git a/vendor/github.com/couchbase/gomemcached/client/tap_feed.go b/vendor/github.com/couchbase/gomemcached/client/tap_feed.go new file mode 100644 index 00000000..fd628c5d --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/client/tap_feed.go @@ -0,0 +1,333 @@ +package memcached + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + + "github.com/couchbase/gomemcached" + "github.com/couchbase/goutils/logging" +) + +// TAP protocol docs: + +// TapOpcode is the tap operation type (found in TapEvent) +type TapOpcode uint8 + +// Tap opcode values. +const ( + TapBeginBackfill = TapOpcode(iota) + TapEndBackfill + TapMutation + TapDeletion + TapCheckpointStart + TapCheckpointEnd + tapEndStream +) + +const tapMutationExtraLen = 16 + +var tapOpcodeNames map[TapOpcode]string + +func init() { + tapOpcodeNames = map[TapOpcode]string{ + TapBeginBackfill: "BeginBackfill", + TapEndBackfill: "EndBackfill", + TapMutation: "Mutation", + TapDeletion: "Deletion", + TapCheckpointStart: "TapCheckpointStart", + TapCheckpointEnd: "TapCheckpointEnd", + tapEndStream: "EndStream", + } +} + +func (opcode TapOpcode) String() string { + name := tapOpcodeNames[opcode] + if name == "" { + name = fmt.Sprintf("#%d", opcode) + } + return name +} + +// TapEvent is a TAP notification of an operation on the server. 
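+//
+// Events are normally drained from a TapFeed channel (editor's sketch; feed
+// construction is shown at StartTapFeed below):
+//
+//	for ev := range feed.C {
+//		switch ev.Opcode {
+//		case TapMutation:
+//			// ev.Key/ev.Value carry the item
+//		case TapDeletion:
+//			// ev.Key was removed
+//		}
+//	}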
+type TapEvent struct {
+	Opcode     TapOpcode // Type of event
+	VBucket    uint16    // VBucket this event applies to
+	Flags      uint32    // Item flags
+	Expiry     uint32    // Item expiration time
+	Key, Value []byte    // Item key/value
+	Cas        uint64
+}
+
+func makeTapEvent(req gomemcached.MCRequest) *TapEvent {
+	event := TapEvent{
+		VBucket: req.VBucket,
+	}
+	switch req.Opcode {
+	case gomemcached.TAP_MUTATION:
+		event.Opcode = TapMutation
+		event.Key = req.Key
+		event.Value = req.Body
+		event.Cas = req.Cas
+	case gomemcached.TAP_DELETE:
+		event.Opcode = TapDeletion
+		event.Key = req.Key
+		event.Cas = req.Cas
+	case gomemcached.TAP_CHECKPOINT_START:
+		event.Opcode = TapCheckpointStart
+	case gomemcached.TAP_CHECKPOINT_END:
+		event.Opcode = TapCheckpointEnd
+	case gomemcached.TAP_OPAQUE:
+		if len(req.Extras) < 8+4 {
+			return nil
+		}
+		switch op := int(binary.BigEndian.Uint32(req.Extras[8:])); op {
+		case gomemcached.TAP_OPAQUE_INITIAL_VBUCKET_STREAM:
+			event.Opcode = TapBeginBackfill
+		case gomemcached.TAP_OPAQUE_CLOSE_BACKFILL:
+			event.Opcode = TapEndBackfill
+		case gomemcached.TAP_OPAQUE_CLOSE_TAP_STREAM:
+			event.Opcode = tapEndStream
+		case gomemcached.TAP_OPAQUE_ENABLE_AUTO_NACK:
+			return nil
+		case gomemcached.TAP_OPAQUE_ENABLE_CHECKPOINT_SYNC:
+			return nil
+		default:
+			logging.Infof("TapFeed: Ignoring TAP_OPAQUE/%d", op)
+			return nil // unknown opaque event
+		}
+	case gomemcached.NOOP:
+		return nil // ignore
+	default:
+		logging.Infof("TapFeed: Ignoring %s", req.Opcode)
+		return nil // unknown event
+	}
+
+	if len(req.Extras) >= tapMutationExtraLen &&
+		(event.Opcode == TapMutation || event.Opcode == TapDeletion) {
+
+		event.Flags = binary.BigEndian.Uint32(req.Extras[8:])
+		event.Expiry = binary.BigEndian.Uint32(req.Extras[12:])
+	}
+
+	return &event
+}
+
+func (event TapEvent) String() string {
+	switch event.Opcode {
+	case TapBeginBackfill, TapEndBackfill, TapCheckpointStart, TapCheckpointEnd:
+		return fmt.Sprintf("<TapEvent %v, vbucket=%d>",
+			event.Opcode, event.VBucket)
+	default:
+		return fmt.Sprintf("<TapEvent %v, key=%q (%d bytes) flags=%x, exp=%d>",
+			event.Opcode, event.Key, len(event.Value),
+			event.Flags, event.Expiry)
+	}
+}
+
+// TapArguments are parameters for requesting a TAP feed.
+//
+// Call DefaultTapArguments to get a default one.
+type TapArguments struct {
+	// Timestamp of oldest item to send.
+	//
+	// Use TapNoBackfill to suppress all past items.
+	Backfill uint64
+	// If set, server will disconnect after sending existing items.
+	Dump bool
+	// The indices of the vbuckets to watch; empty/nil to watch all.
+	VBuckets []uint16
+	// Transfers ownership of vbuckets during cluster rebalance.
+	Takeover bool
+	// If true, server will wait for client ACK after every notification.
+	SupportAck bool
+	// If true, client doesn't want values so server shouldn't send them.
+	KeysOnly bool
+	// If true, client wants the server to send checkpoint events.
+	Checkpoint bool
+	// Optional identifier to use for this client, to allow reconnects
+	ClientName string
+	// Registers this client (by name) till explicitly deregistered.
+	RegisteredClient bool
+}
+
+// Value for TapArguments.Backfill denoting that no past events at all
+// should be sent.
+const TapNoBackfill = math.MaxUint64
+
+// DefaultTapArguments returns a default set of parameter values to
+// pass to StartTapFeed.
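+//
+// Putting it together (editor's sketch; mc is assumed to be a connected,
+// authenticated *Client):
+//
+//	args := DefaultTapArguments()
+//	args.KeysOnly = true // illustrative tweak: skip values
+//	feed, err := mc.StartTapFeed(args)
+//	if err == nil {
+//		defer feed.Close()
+//		// drain feed.C as shown on TapEvent above
+//	}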
+func DefaultTapArguments() TapArguments { + return TapArguments{ + Backfill: TapNoBackfill, + } +} + +func (args *TapArguments) flags() []byte { + var flags gomemcached.TapConnectFlag + if args.Backfill != 0 { + flags |= gomemcached.BACKFILL + } + if args.Dump { + flags |= gomemcached.DUMP + } + if len(args.VBuckets) > 0 { + flags |= gomemcached.LIST_VBUCKETS + } + if args.Takeover { + flags |= gomemcached.TAKEOVER_VBUCKETS + } + if args.SupportAck { + flags |= gomemcached.SUPPORT_ACK + } + if args.KeysOnly { + flags |= gomemcached.REQUEST_KEYS_ONLY + } + if args.Checkpoint { + flags |= gomemcached.CHECKPOINT + } + if args.RegisteredClient { + flags |= gomemcached.REGISTERED_CLIENT + } + encoded := make([]byte, 4) + binary.BigEndian.PutUint32(encoded, uint32(flags)) + return encoded +} + +func must(err error) { + if err != nil { + panic(err) + } +} + +func (args *TapArguments) bytes() (rv []byte) { + buf := bytes.NewBuffer([]byte{}) + + if args.Backfill > 0 { + must(binary.Write(buf, binary.BigEndian, uint64(args.Backfill))) + } + + if len(args.VBuckets) > 0 { + must(binary.Write(buf, binary.BigEndian, uint16(len(args.VBuckets)))) + for i := 0; i < len(args.VBuckets); i++ { + must(binary.Write(buf, binary.BigEndian, uint16(args.VBuckets[i]))) + } + } + return buf.Bytes() +} + +// TapFeed represents a stream of events from a server. +type TapFeed struct { + C <-chan TapEvent + Error error + closer chan bool +} + +// StartTapFeed starts a TAP feed on a client connection. +// +// The events can be read from the returned channel. The connection +// can no longer be used for other purposes; it's now reserved for +// receiving the TAP messages. To stop receiving events, close the +// client connection. +func (mc *Client) StartTapFeed(args TapArguments) (*TapFeed, error) { + rq := &gomemcached.MCRequest{ + Opcode: gomemcached.TAP_CONNECT, + Key: []byte(args.ClientName), + Extras: args.flags(), + Body: args.bytes()} + + err := mc.Transmit(rq) + if err != nil { + return nil, err + } + + ch := make(chan TapEvent) + feed := &TapFeed{ + C: ch, + closer: make(chan bool), + } + go mc.runFeed(ch, feed) + return feed, nil +} + +// TapRecvHook is called after every incoming tap packet is received. +var TapRecvHook func(*gomemcached.MCRequest, int, error) + +// Internal goroutine that reads from the socket and writes events to +// the channel +func (mc *Client) runFeed(ch chan TapEvent, feed *TapFeed) { + defer close(ch) + var headerBuf [gomemcached.HDR_LEN]byte +loop: + for { + // Read the next request from the server. + // + // (Can't call mc.Receive() because it reads a + // _response_ not a request.) + var pkt gomemcached.MCRequest + n, err := pkt.Receive(mc.conn, headerBuf[:]) + if TapRecvHook != nil { + TapRecvHook(&pkt, n, err) + } + + if err != nil { + if err != io.EOF { + feed.Error = err + } + break loop + } + + //logging.Infof("** TapFeed received %#v : %q", pkt, pkt.Body) + + if pkt.Opcode == gomemcached.TAP_CONNECT { + // This is not an event from the server; it's + // an error response to my connect request. 
+ feed.Error = fmt.Errorf("tap connection failed: %s", pkt.Body) + break loop + } + + event := makeTapEvent(pkt) + if event != nil { + if event.Opcode == tapEndStream { + break loop + } + + select { + case ch <- *event: + case <-feed.closer: + break loop + } + } + + if len(pkt.Extras) >= 4 { + reqFlags := binary.BigEndian.Uint16(pkt.Extras[2:]) + if reqFlags&gomemcached.TAP_ACK != 0 { + if _, err := mc.sendAck(&pkt); err != nil { + feed.Error = err + break loop + } + } + } + } + if err := mc.Close(); err != nil { + logging.Errorf("Error closing memcached client: %v", err) + } +} + +func (mc *Client) sendAck(pkt *gomemcached.MCRequest) (int, error) { + res := gomemcached.MCResponse{ + Opcode: pkt.Opcode, + Opaque: pkt.Opaque, + Status: gomemcached.SUCCESS, + } + return res.Transmit(mc.conn) +} + +// Close terminates a TapFeed. +// +// Call this if you stop using a TapFeed before its channel ends. +func (feed *TapFeed) Close() { + close(feed.closer) +} diff --git a/vendor/github.com/couchbase/gomemcached/client/transport.go b/vendor/github.com/couchbase/gomemcached/client/transport.go new file mode 100644 index 00000000..f4cea17f --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/client/transport.go @@ -0,0 +1,67 @@ +package memcached + +import ( + "errors" + "io" + + "github.com/couchbase/gomemcached" +) + +var errNoConn = errors.New("no connection") + +// UnwrapMemcachedError converts memcached errors to normal responses. +// +// If the error is a memcached response, declare the error to be nil +// so a client can handle the status without worrying about whether it +// indicates success or failure. +func UnwrapMemcachedError(rv *gomemcached.MCResponse, + err error) (*gomemcached.MCResponse, error) { + + if rv == err { + return rv, nil + } + return rv, err +} + +// ReceiveHook is called after every packet is received (or attempted to be) +var ReceiveHook func(*gomemcached.MCResponse, int, error) + +func getResponse(s io.Reader, hdrBytes []byte) (rv *gomemcached.MCResponse, n int, err error) { + if s == nil { + return nil, 0, errNoConn + } + + rv = &gomemcached.MCResponse{} + n, err = rv.Receive(s, hdrBytes) + + if ReceiveHook != nil { + ReceiveHook(rv, n, err) + } + + if err == nil && (rv.Status != gomemcached.SUCCESS && rv.Status != gomemcached.AUTH_CONTINUE) { + err = rv + } + return rv, n, err +} + +// TransmitHook is called after each packet is transmitted. +var TransmitHook func(*gomemcached.MCRequest, int, error) + +func transmitRequest(o io.Writer, req *gomemcached.MCRequest) (int, error) { + if o == nil { + return 0, errNoConn + } + n, err := req.Transmit(o) + if TransmitHook != nil { + TransmitHook(req, n, err) + } + return n, err +} + +func transmitResponse(o io.Writer, res *gomemcached.MCResponse) (int, error) { + if o == nil { + return 0, errNoConn + } + n, err := res.Transmit(o) + return n, err +} diff --git a/vendor/github.com/couchbase/gomemcached/client/upr_feed.go b/vendor/github.com/couchbase/gomemcached/client/upr_feed.go new file mode 100644 index 00000000..dc737e6c --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/client/upr_feed.go @@ -0,0 +1,1005 @@ +// go implementation of upr client. +// See https://github.com/couchbaselabs/cbupr/blob/master/transport-spec.md +// TODO +// 1. 
Use a pool allocator to avoid garbage
+package memcached
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"github.com/couchbase/gomemcached"
+	"github.com/couchbase/goutils/logging"
+	"strconv"
+	"sync"
+	"sync/atomic"
+)
+
+const uprMutationExtraLen = 30
+const uprDeletetionExtraLen = 18
+const uprDeletetionWithDeletionTimeExtraLen = 21
+const uprSnapshotExtraLen = 20
+const bufferAckThreshold = 0.2
+const opaqueOpen = 0xBEAF0001
+const opaqueFailover = 0xDEADBEEF
+const uprDefaultNoopInterval = 120
+
+// Counter on top of opaqueOpen that others can draw from for open and control msgs
+var opaqueOpenCtrlWell uint32 = opaqueOpen
+
+// UprEvent memcached events for UPR streams.
+type UprEvent struct {
+	Opcode       gomemcached.CommandCode // Type of event
+	Status       gomemcached.Status      // Response status
+	VBucket      uint16                  // VBucket this event applies to
+	DataType     uint8                   // data type
+	Opaque       uint16                  // 16 MSB of opaque
+	VBuuid       uint64                  // This field is set by downstream
+	Flags        uint32                  // Item flags
+	Expiry       uint32                  // Item expiration time
+	Key, Value   []byte                  // Item key/value
+	OldValue     []byte                  // TODO: TBD: old document value
+	Cas          uint64                  // CAS value of the item
+	Seqno        uint64                  // sequence number of the mutation
+	RevSeqno     uint64                  // rev sequence number : deletions
+	LockTime     uint32                  // Lock time
+	MetadataSize uint16                  // Metadata size
+	SnapstartSeq uint64                  // start sequence number of this snapshot
+	SnapendSeq   uint64                  // End sequence number of the snapshot
+	SnapshotType uint32                  // 0: disk 1: memory
+	FailoverLog  *FailoverLog            // Failover log containing vbuuid and sequence number
+	Error        error                   // Error value in case of a failure
+	ExtMeta      []byte
+	AckSize      uint32 // The number of bytes that can be Acked to DCP
+}
+
+// UprStream is a per-stream data structure over an UPR connection.
+type UprStream struct {
+	Vbucket   uint16 // Vbucket id
+	Vbuuid    uint64 // vbucket uuid
+	StartSeq  uint64 // start sequence number
+	EndSeq    uint64 // end sequence number
+	connected bool
+}
+
+const (
+	CompressionTypeStartMarker = iota // also means invalid
+	CompressionTypeNone        = iota
+	CompressionTypeSnappy      = iota
+	CompressionTypeEndMarker   = iota // also means invalid
+)
+
+// kv_engine/include/mcbp/protocol/datatype.h
+const (
+	JSONDataType   uint8 = 1
+	SnappyDataType uint8 = 2
+	XattrDataType  uint8 = 4
+)
+
+type UprFeatures struct {
+	Xattribute          bool
+	CompressionType     int
+	IncludeDeletionTime bool
+}
+
+/**
+ * Used to handle multiple concurrent calls to UprRequestStream() by UprFeed clients.
+ * It is expected that a client that calls UprRequestStream() more than once should issue
+ * different "opaque" (version) numbers.
+ */
+type opaqueStreamMap map[uint16]*UprStream // opaque -> stream
+
+type vbStreamNegotiator struct {
+	vbHandshakeMap map[uint16]opaqueStreamMap // vbno -> opaqueStreamMap
+	mutex          sync.RWMutex
+}
+
+func (negotiator *vbStreamNegotiator) initialize() {
+	negotiator.mutex.Lock()
+	negotiator.vbHandshakeMap = make(map[uint16]opaqueStreamMap)
+	negotiator.mutex.Unlock()
+}
+
+func (negotiator *vbStreamNegotiator) registerRequest(vbno, opaque uint16, vbuuid, startSequence, endSequence uint64) {
+	negotiator.mutex.Lock()
+	defer negotiator.mutex.Unlock()
+
+	var osMap opaqueStreamMap
+	var ok bool
+	if osMap, ok = negotiator.vbHandshakeMap[vbno]; !ok {
+		osMap = make(opaqueStreamMap)
+		negotiator.vbHandshakeMap[vbno] = osMap
+	}
+
+	if _, ok = osMap[opaque]; !ok {
+		osMap[opaque] = &UprStream{
+			Vbucket:  vbno,
+			Vbuuid:   vbuuid,
+			StartSeq: startSequence,
+			EndSeq:   endSequence,
+		}
+	}
+}
+
+func (negotiator *vbStreamNegotiator) getStreamsCntFromMap(vbno uint16) int {
+	negotiator.mutex.RLock()
+	defer negotiator.mutex.RUnlock()
+
+	osmap, ok := negotiator.vbHandshakeMap[vbno]
+	if !ok {
+		return 0
+	}
+	return len(osmap)
+}
+
+func (negotiator *vbStreamNegotiator) getStreamFromMap(vbno, opaque uint16) (*UprStream, error) {
+	negotiator.mutex.RLock()
+	defer negotiator.mutex.RUnlock()
+
+	osmap, ok := negotiator.vbHandshakeMap[vbno]
+	if !ok {
+		return nil, fmt.Errorf("Error: stream for vb: %v does not exist", vbno)
+	}
+
+	stream, ok := osmap[opaque]
+	if !ok {
+		return nil, fmt.Errorf("Error: stream for vb: %v opaque: %v does not exist", vbno, opaque)
+	}
+	return stream, nil
+}
+
+func (negotiator *vbStreamNegotiator) deleteStreamFromMap(vbno, opaque uint16) {
+	negotiator.mutex.Lock()
+	defer negotiator.mutex.Unlock()
+
+	osmap, ok := negotiator.vbHandshakeMap[vbno]
+	if !ok {
+		return
+	}
+
+	delete(osmap, opaque)
+	if len(osmap) == 0 {
+		delete(negotiator.vbHandshakeMap, vbno)
+	}
+}
+
+func (negotiator *vbStreamNegotiator) handleStreamRequest(feed *UprFeed,
+	headerBuf [gomemcached.HDR_LEN]byte, pktPtr *gomemcached.MCRequest, bytesReceivedFromDCP int,
+	response *gomemcached.MCResponse) (*UprEvent, error) {
+	var event *UprEvent
+
+	if feed == nil || response == nil || pktPtr == nil {
+		return nil, errors.New("Invalid inputs")
+	}
+
+	// Get Stream from negotiator map
+	vbno := vbOpaque(response.Opaque)
+	opaque := appOpaque(response.Opaque)
+
+	stream, err := negotiator.getStreamFromMap(vbno, opaque)
+	if err != nil {
+		err = fmt.Errorf("Stream not found for vb %d appOpaque %v: %#v", vbno, opaque, *pktPtr)
+		logging.Errorf(err.Error())
+		return nil, err
+	}
+
+	status, rb, flog, err := handleStreamRequest(response, headerBuf[:])
+
+	if status == gomemcached.ROLLBACK {
+		event = makeUprEvent(*pktPtr, stream, bytesReceivedFromDCP)
+		event.Status = status
+		// rollback stream
+		logging.Infof("UPR_STREAMREQ with rollback %d for vb %d Failed: %v", rb, vbno, err)
+		negotiator.deleteStreamFromMap(vbno, opaque)
+	} else if status == gomemcached.SUCCESS {
+		event = makeUprEvent(*pktPtr, stream, bytesReceivedFromDCP)
+		event.Seqno = stream.StartSeq
+		event.FailoverLog = flog
+		event.Status = status
+		feed.activateStream(vbno, opaque, stream)
+		feed.negotiator.deleteStreamFromMap(vbno, opaque)
+		logging.Infof("UPR_STREAMREQ for vb %d successful", vbno)
+
+	} else if err != nil {
+		logging.Errorf("UPR_STREAMREQ for vbucket %d error %s", vbno, err.Error())
+		event = &UprEvent{
+			Opcode:  gomemcached.UPR_STREAMREQ,
+			Status:  status,
+			VBucket: vbno,
+			Error:   err,
+		}
+		negotiator.deleteStreamFromMap(vbno, opaque)
+	}
+	return event, nil
+}
+
+func (negotiator *vbStreamNegotiator) cleanUpVbStreams(vbno uint16) {
+	negotiator.mutex.Lock()
+	defer negotiator.mutex.Unlock()
+
+	delete(negotiator.vbHandshakeMap, vbno)
+}
+
+// UprFeed represents an UPR feed. A feed contains a connection to a single
+// host and multiple vBuckets.
+type UprFeed struct {
+	// lock for feed.vbstreams
+	muVbstreams sync.RWMutex
+	// lock for feed.closed
+	muClosed    sync.RWMutex
+	C           <-chan *UprEvent            // Exported channel for receiving UPR events
+	negotiator  vbStreamNegotiator          // Used for pre-vbstreams, concurrent vb stream negotiation
+	vbstreams   map[uint16]*UprStream       // official live vb->stream mapping
+	closer      chan bool                   // closer
+	conn        *Client                     // connection to UPR producer
+	Error       error                       // error
+	bytesRead   uint64                      // total bytes read on this connection
+	toAckBytes  uint32                      // bytes client has read
+	maxAckBytes uint32                      // Max buffer control ack bytes
+	stats       UprStats                    // Stats for upr client
+	transmitCh  chan *gomemcached.MCRequest // transmit command channel
+	transmitCl  chan bool                   // closer channel for transmit go-routine
+	closed      bool                        // flag indicating whether the feed has been closed
+	// flag indicating whether client of upr feed will send ack to upr feed
+	// if flag is true, upr feed will use ack from client to determine whether/when to send ack to DCP
+	// if flag is false, upr feed will track how many bytes it has sent to client
+	// and use that to determine whether/when to send ack to DCP
+	ackByClient bool
+}
+
+// Exported interface - to allow for mocking
+type UprFeedIface interface {
+	Close()
+	Closed() bool
+	CloseStream(vbno, opaqueMSB uint16) error
+	GetError() error
+	GetUprStats() *UprStats
+	ClientAck(event *UprEvent) error
+	GetUprEventCh() <-chan *UprEvent
+	StartFeed() error
+	StartFeedWithConfig(datachan_len int) error
+	UprOpen(name string, sequence uint32, bufSize uint32) error
+	UprOpenWithXATTR(name string, sequence uint32, bufSize uint32) error
+	UprOpenWithFeatures(name string, sequence uint32, bufSize uint32, features UprFeatures) (error, UprFeatures)
+	UprRequestStream(vbno, opaqueMSB uint16, flags uint32, vuuid, startSequence, endSequence, snapStart, snapEnd uint64) error
+}
+
+type UprStats struct {
+	TotalBytes         uint64
+	TotalMutation      uint64
+	TotalBufferAckSent uint64
+	TotalSnapShot      uint64
+}
+
+// FailoverLog is a list of (vbuuid, sequence number) pairs.
+type FailoverLog [][2]uint64
+
+// error codes
+var ErrorInvalidLog = errors.New("couchbase.errorInvalidLog")
+
+func (flogp *FailoverLog) Latest() (vbuuid, seqno uint64, err error) {
+	if flogp != nil {
+		flog := *flogp
+		latest := flog[len(flog)-1]
+		return latest[0], latest[1], nil
+	}
+
+	return vbuuid, seqno, ErrorInvalidLog
+}
+
+func makeUprEvent(rq gomemcached.MCRequest, stream *UprStream, bytesReceivedFromDCP int) *UprEvent {
+	event := &UprEvent{
+		Opcode:   rq.Opcode,
+		VBucket:  stream.Vbucket,
+		VBuuid:   stream.Vbuuid,
+		Key:      rq.Key,
+		Value:    rq.Body,
+		Cas:      rq.Cas,
+		ExtMeta:  rq.ExtMeta,
+		DataType: rq.DataType,
+	}
+
+	// set AckSize for events that need to be acked to DCP,
+	// i.e., events with CommandCodes that need to be buffered in DCP
+	if _, ok := gomemcached.BufferedCommandCodeMap[rq.Opcode]; ok {
+		event.AckSize = uint32(bytesReceivedFromDCP)
+	}
+
+	// 16 LSBits are used by client library to encode vbucket number.
+	// 16 MSBits are left for application to multiplex on opaque value.
+	event.Opaque = appOpaque(rq.Opaque)
+
+	if len(rq.Extras) >= uprMutationExtraLen &&
+		event.Opcode == gomemcached.UPR_MUTATION {
+
+		event.Seqno = binary.BigEndian.Uint64(rq.Extras[:8])
+		event.RevSeqno = binary.BigEndian.Uint64(rq.Extras[8:16])
+		event.Flags = binary.BigEndian.Uint32(rq.Extras[16:20])
+		event.Expiry = binary.BigEndian.Uint32(rq.Extras[20:24])
+		event.LockTime = binary.BigEndian.Uint32(rq.Extras[24:28])
+		event.MetadataSize = binary.BigEndian.Uint16(rq.Extras[28:30])
+
+	} else if len(rq.Extras) >= uprDeletetionWithDeletionTimeExtraLen &&
+		event.Opcode == gomemcached.UPR_DELETION {
+
+		event.Seqno = binary.BigEndian.Uint64(rq.Extras[:8])
+		event.RevSeqno = binary.BigEndian.Uint64(rq.Extras[8:16])
+		event.Expiry = binary.BigEndian.Uint32(rq.Extras[16:20])
+
+	} else if len(rq.Extras) >= uprDeletetionExtraLen &&
+		(event.Opcode == gomemcached.UPR_DELETION ||
+			event.Opcode == gomemcached.UPR_EXPIRATION) {
+
+		event.Seqno = binary.BigEndian.Uint64(rq.Extras[:8])
+		event.RevSeqno = binary.BigEndian.Uint64(rq.Extras[8:16])
+		event.MetadataSize = binary.BigEndian.Uint16(rq.Extras[16:18])
+
+	} else if len(rq.Extras) >= uprSnapshotExtraLen &&
+		event.Opcode == gomemcached.UPR_SNAPSHOT {
+
+		event.SnapstartSeq = binary.BigEndian.Uint64(rq.Extras[:8])
+		event.SnapendSeq = binary.BigEndian.Uint64(rq.Extras[8:16])
+		event.SnapshotType = binary.BigEndian.Uint32(rq.Extras[16:20])
+	}
+
+	return event
+}
+
+func (event *UprEvent) String() string {
+	name := gomemcached.CommandNames[event.Opcode]
+	if name == "" {
+		name = fmt.Sprintf("#%d", event.Opcode)
+	}
+	return name
+}
+
+func (event *UprEvent) IsSnappyDataType() bool {
+	return event.Opcode == gomemcached.UPR_MUTATION && (event.DataType&SnappyDataType > 0)
+}
+
+func (feed *UprFeed) sendCommands(mc *Client) {
+	transmitCh := feed.transmitCh
+	transmitCl := feed.transmitCl
+loop:
+	for {
+		select {
+		case command := <-transmitCh:
+			if err := mc.Transmit(command); err != nil {
+				logging.Errorf("Failed to transmit command %s. Error %s", command.Opcode.String(), err.Error())
Error %s", command.Opcode.String(), err.Error()) + // get feed to close and runFeed routine to exit + feed.Close() + break loop + } + + case <-transmitCl: + break loop + } + } + + // After sendCommands exits, write to transmitCh will block forever + // when we write to transmitCh, e.g., at CloseStream(), we need to check feed closure to have an exit route + + logging.Infof("sendCommands exiting") +} + +// Sets the specified stream as the connected stream for this vbno, and also cleans up negotiator +func (feed *UprFeed) activateStream(vbno, opaque uint16, stream *UprStream) error { + feed.muVbstreams.Lock() + defer feed.muVbstreams.Unlock() + + // Set this stream as the officially connected stream for this vb + stream.connected = true + feed.vbstreams[vbno] = stream + return nil +} + +func (feed *UprFeed) cleanUpVbStream(vbno uint16) { + feed.muVbstreams.Lock() + defer feed.muVbstreams.Unlock() + + delete(feed.vbstreams, vbno) +} + +// NewUprFeed creates a new UPR Feed. +// TODO: Describe side-effects on bucket instance and its connection pool. +func (mc *Client) NewUprFeed() (*UprFeed, error) { + return mc.NewUprFeedWithConfig(false /*ackByClient*/) +} + +func (mc *Client) NewUprFeedWithConfig(ackByClient bool) (*UprFeed, error) { + + feed := &UprFeed{ + conn: mc, + closer: make(chan bool, 1), + vbstreams: make(map[uint16]*UprStream), + transmitCh: make(chan *gomemcached.MCRequest), + transmitCl: make(chan bool), + ackByClient: ackByClient, + } + + feed.negotiator.initialize() + + go feed.sendCommands(mc) + return feed, nil +} + +func (mc *Client) NewUprFeedIface() (UprFeedIface, error) { + return mc.NewUprFeed() +} + +func (mc *Client) NewUprFeedWithConfigIface(ackByClient bool) (UprFeedIface, error) { + return mc.NewUprFeedWithConfig(ackByClient) +} + +func doUprOpen(mc *Client, name string, sequence uint32, features UprFeatures) error { + rq := &gomemcached.MCRequest{ + Opcode: gomemcached.UPR_OPEN, + Key: []byte(name), + Opaque: getUprOpenCtrlOpaque(), + } + + rq.Extras = make([]byte, 8) + binary.BigEndian.PutUint32(rq.Extras[:4], sequence) + + // opens a producer type connection + flags := gomemcached.DCP_PRODUCER + if features.Xattribute { + flags = flags | gomemcached.DCP_OPEN_INCLUDE_XATTRS + } + if features.IncludeDeletionTime { + flags = flags | gomemcached.DCP_OPEN_INCLUDE_DELETE_TIMES + } + binary.BigEndian.PutUint32(rq.Extras[4:], flags) + + return sendMcRequestSync(mc, rq) +} + +// Synchronously send a memcached request and wait for the response +func sendMcRequestSync(mc *Client, req *gomemcached.MCRequest) error { + if err := mc.Transmit(req); err != nil { + return err + } + + if res, err := mc.Receive(); err != nil { + return err + } else if req.Opcode != res.Opcode { + return fmt.Errorf("unexpected #opcode sent %v received %v", req.Opcode, res.Opaque) + } else if req.Opaque != res.Opaque { + return fmt.Errorf("opaque mismatch, sent %v received %v", req.Opaque, res.Opaque) + } else if res.Status != gomemcached.SUCCESS { + return fmt.Errorf("error %v", res.Status) + } + return nil +} + +// UprOpen to connect with a UPR producer. +// Name: name of te UPR connection +// sequence: sequence number for the connection +// bufsize: max size of the application +func (feed *UprFeed) UprOpen(name string, sequence uint32, bufSize uint32) error { + var allFeaturesDisabled UprFeatures + err, _ := feed.uprOpen(name, sequence, bufSize, allFeaturesDisabled) + return err +} + +// UprOpen with XATTR enabled. 
+// UprOpen with XATTR enabled.
+func (feed *UprFeed) UprOpenWithXATTR(name string, sequence uint32, bufSize uint32) error {
+	var onlyXattrEnabled UprFeatures
+	onlyXattrEnabled.Xattribute = true
+	err, _ := feed.uprOpen(name, sequence, bufSize, onlyXattrEnabled)
+	return err
+}
+
+func (feed *UprFeed) UprOpenWithFeatures(name string, sequence uint32, bufSize uint32, features UprFeatures) (error, UprFeatures) {
+	return feed.uprOpen(name, sequence, bufSize, features)
+}
+
+func (feed *UprFeed) uprOpen(name string, sequence uint32, bufSize uint32, features UprFeatures) (err error, activatedFeatures UprFeatures) {
+	mc := feed.conn
+
+	// First set this to an invalid value to state that the method hasn't gotten to executing this control yet
+	activatedFeatures.CompressionType = CompressionTypeEndMarker
+
+	if err = doUprOpen(mc, name, sequence, features); err != nil {
+		return
+	}
+
+	activatedFeatures.Xattribute = features.Xattribute
+
+	// send a UPR control message to set the window size for this connection
+	if bufSize > 0 {
+		rq := &gomemcached.MCRequest{
+			Opcode: gomemcached.UPR_CONTROL,
+			Key:    []byte("connection_buffer_size"),
+			Body:   []byte(strconv.Itoa(int(bufSize))),
+			Opaque: getUprOpenCtrlOpaque(),
+		}
+		err = sendMcRequestSync(feed.conn, rq)
+		if err != nil {
+			return
+		}
+		feed.maxAckBytes = uint32(bufferAckThreshold * float32(bufSize))
+	}
+
+	// enable noop and set noop interval
+	rq := &gomemcached.MCRequest{
+		Opcode: gomemcached.UPR_CONTROL,
+		Key:    []byte("enable_noop"),
+		Body:   []byte("true"),
+		Opaque: getUprOpenCtrlOpaque(),
+	}
+	err = sendMcRequestSync(feed.conn, rq)
+	if err != nil {
+		return
+	}
+
+	rq = &gomemcached.MCRequest{
+		Opcode: gomemcached.UPR_CONTROL,
+		Key:    []byte("set_noop_interval"),
+		Body:   []byte(strconv.Itoa(int(uprDefaultNoopInterval))),
+		Opaque: getUprOpenCtrlOpaque(),
+	}
+	err = sendMcRequestSync(feed.conn, rq)
+	if err != nil {
+		return
+	}
+
+	if features.CompressionType == CompressionTypeSnappy {
+		activatedFeatures.CompressionType = CompressionTypeNone
+		rq = &gomemcached.MCRequest{
+			Opcode: gomemcached.UPR_CONTROL,
+			Key:    []byte("force_value_compression"),
+			Body:   []byte("true"),
+			Opaque: getUprOpenCtrlOpaque(),
+		}
+		err = sendMcRequestSync(feed.conn, rq)
+	} else if features.CompressionType == CompressionTypeEndMarker {
+		err = fmt.Errorf("UPR_CONTROL Failed - Invalid CompressionType: %v", features.CompressionType)
+	}
+	if err != nil {
+		return
+	}
+	activatedFeatures.CompressionType = features.CompressionType
+
+	return
+}
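// Sketch of negotiating optional features at open time. UprFeatures and the
// CompressionType* constants are defined earlier in this file; the activated
// feature set must be read back from the returned UprFeatures rather than
// assumed from the request.
//
//	var want UprFeatures
//	want.Xattribute = true
//	want.CompressionType = CompressionTypeSnappy
//	err, got := feed.UprOpenWithFeatures("example-conn", 0, 1024*1024, want)
//	if err != nil { /* open or control negotiation failed */ }
//	useSnappy := got.CompressionType == CompressionTypeSnappy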
+// UprGetFailoverLog for given list of vbuckets.
+func (mc *Client) UprGetFailoverLog(
+	vb []uint16) (map[uint16]*FailoverLog, error) {
+
+	rq := &gomemcached.MCRequest{
+		Opcode: gomemcached.UPR_FAILOVERLOG,
+		Opaque: opaqueFailover,
+	}
+
+	var allFeaturesDisabled UprFeatures
+	if err := doUprOpen(mc, "FailoverLog", 0, allFeaturesDisabled); err != nil {
+		return nil, fmt.Errorf("UPR_OPEN Failed %s", err.Error())
+	}
+
+	failoverLogs := make(map[uint16]*FailoverLog)
+	for _, vBucket := range vb {
+		rq.VBucket = vBucket
+		if err := mc.Transmit(rq); err != nil {
+			return nil, err
+		}
+		res, err := mc.Receive()
+
+		if err != nil {
+			return nil, fmt.Errorf("failed to receive %s", err.Error())
+		} else if res.Opcode != gomemcached.UPR_FAILOVERLOG || res.Status != gomemcached.SUCCESS {
+			return nil, fmt.Errorf("unexpected #opcode %v or status %v", res.Opcode, res.Status)
+		}
+
+		flog, err := parseFailoverLog(res.Body)
+		if err != nil {
+			return nil, fmt.Errorf("unable to parse failover logs for vb %d", vBucket)
+		}
+		failoverLogs[vBucket] = flog
+	}
+
+	return failoverLogs, nil
+}
+
+// UprRequestStream for a single vbucket.
+func (feed *UprFeed) UprRequestStream(vbno, opaqueMSB uint16, flags uint32,
+	vuuid, startSequence, endSequence, snapStart, snapEnd uint64) error {
+
+	rq := &gomemcached.MCRequest{
+		Opcode:  gomemcached.UPR_STREAMREQ,
+		VBucket: vbno,
+		Opaque:  composeOpaque(vbno, opaqueMSB),
+	}
+
+	rq.Extras = make([]byte, 48) // #Extras
+	binary.BigEndian.PutUint32(rq.Extras[:4], flags)
+	binary.BigEndian.PutUint32(rq.Extras[4:8], uint32(0))
+	binary.BigEndian.PutUint64(rq.Extras[8:16], startSequence)
+	binary.BigEndian.PutUint64(rq.Extras[16:24], endSequence)
+	binary.BigEndian.PutUint64(rq.Extras[24:32], vuuid)
+	binary.BigEndian.PutUint64(rq.Extras[32:40], snapStart)
+	binary.BigEndian.PutUint64(rq.Extras[40:48], snapEnd)
+
+	feed.negotiator.registerRequest(vbno, opaqueMSB, vuuid, startSequence, endSequence)
+	// Any client that has ever called this method, regardless of return code,
+	// should expect a potential UPR_CLOSESTREAM message due to this new map entry prior to Transmit.
+
+	if err := feed.conn.Transmit(rq); err != nil {
+		logging.Errorf("Error in StreamRequest %s", err.Error())
+		// If an error occurs during transmit, then the UPRFeed will keep the stream
+		// in the vbstreams map. This is to prevent nil lookup from any previously
+		// sent stream requests.
+		return err
+	}
+
+	return nil
+}
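// A small worked example of how the 32-bit opaque is packed, matching
// composeOpaque/appOpaque/vbOpaque defined later in this file: the low 16 bits
// carry the vbucket number, the high 16 bits are free for the application to
// multiplex on.
//
//	opq := composeOpaque(1023, 7) // == (7 << 16) | 1023
//	vbOpaque(opq)                 // == 1023, the vbucket number
//	appOpaque(opq)                // == 7, the application's opaqueMSB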
+// CloseStream for specified vbucket.
+func (feed *UprFeed) CloseStream(vbno, opaqueMSB uint16) error {
+
+	err := feed.validateCloseStream(vbno)
+	if err != nil {
+		logging.Infof("CloseStream for %v has been skipped because of error %v", vbno, err)
+		return err
+	}
+
+	closeStream := &gomemcached.MCRequest{
+		Opcode:  gomemcached.UPR_CLOSESTREAM,
+		VBucket: vbno,
+		Opaque:  composeOpaque(vbno, opaqueMSB),
+	}
+
+	feed.writeToTransmitCh(closeStream)
+
+	return nil
+}
+
+func (feed *UprFeed) GetUprEventCh() <-chan *UprEvent {
+	return feed.C
+}
+
+func (feed *UprFeed) GetError() error {
+	return feed.Error
+}
+
+func (feed *UprFeed) validateCloseStream(vbno uint16) error {
+	feed.muVbstreams.RLock()
+	nilVbStream := feed.vbstreams[vbno] == nil
+	feed.muVbstreams.RUnlock()
+
+	if nilVbStream && (feed.negotiator.getStreamsCntFromMap(vbno) == 0) {
+		return fmt.Errorf("Stream for vb %d has not been requested", vbno)
+	}
+
+	return nil
+}
+
+func (feed *UprFeed) writeToTransmitCh(rq *gomemcached.MCRequest) error {
+	// write to transmitCh may block forever if sendCommands has exited
+	// check for feed closure to have an exit route in this case
+	select {
+	case <-feed.closer:
+		errMsg := fmt.Sprintf("Abort sending request to transmitCh because feed has been closed. request=%v", rq)
+		logging.Infof(errMsg)
+		return errors.New(errMsg)
+	case feed.transmitCh <- rq:
+	}
+	return nil
+}
+
+// StartFeed to start the upr feed.
+func (feed *UprFeed) StartFeed() error {
+	return feed.StartFeedWithConfig(10)
+}
+
+func (feed *UprFeed) StartFeedWithConfig(datachan_len int) error {
+	ch := make(chan *UprEvent, datachan_len)
+	feed.C = ch
+	go feed.runFeed(ch)
+	return nil
+}
+
+func parseFailoverLog(body []byte) (*FailoverLog, error) {
+
+	if len(body)%16 != 0 {
+		err := fmt.Errorf("invalid body length %v, in failover-log", len(body))
+		return nil, err
+	}
+	log := make(FailoverLog, len(body)/16)
+	for i, j := 0, 0; i < len(body); i += 16 {
+		vuuid := binary.BigEndian.Uint64(body[i : i+8])
+		seqno := binary.BigEndian.Uint64(body[i+8 : i+16])
+		log[j] = [2]uint64{vuuid, seqno}
+		j++
+	}
+	return &log, nil
+}
+
+func handleStreamRequest(
+	res *gomemcached.MCResponse,
+	headerBuf []byte,
+) (gomemcached.Status, uint64, *FailoverLog, error) {
+
+	var rollback uint64
+	var err error
+
+	switch {
+	case res.Status == gomemcached.ROLLBACK:
+		logging.Infof("Rollback response. body=%v, headerBuf=%v\n", res.Body, headerBuf)
+		rollback = binary.BigEndian.Uint64(res.Body)
+		logging.Infof("Rollback seqno is %v for response with opaque %v\n", rollback, res.Opaque)
+		return res.Status, rollback, nil, nil
+
+	case res.Status != gomemcached.SUCCESS:
+		err = fmt.Errorf("unexpected status %v for response with opaque %v", res.Status, res.Opaque)
+		return res.Status, 0, nil, err
+	}
+
+	flog, err := parseFailoverLog(res.Body[:])
+	return res.Status, rollback, flog, err
+}
+
+// generate stream end responses for all active vb streams
+func (feed *UprFeed) doStreamClose(ch chan *UprEvent) {
+	feed.muVbstreams.RLock()
+
+	uprEvents := make([]*UprEvent, len(feed.vbstreams))
+	index := 0
+	for vbno, stream := range feed.vbstreams {
+		uprEvent := &UprEvent{
+			VBucket: vbno,
+			VBuuid:  stream.Vbuuid,
+			Opcode:  gomemcached.UPR_STREAMEND,
+		}
+		uprEvents[index] = uprEvent
+		index++
+	}
+
+	// release the lock before sending uprEvents to ch, which may block
+	feed.muVbstreams.RUnlock()
+
+loop:
+	for _, uprEvent := range uprEvents {
+		select {
+		case ch <- uprEvent:
+		case <-feed.closer:
+			logging.Infof("Feed has been closed. Aborting doStreamClose.")
Aborting doStreamClose.") + break loop + } + } +} + +func (feed *UprFeed) runFeed(ch chan *UprEvent) { + defer close(ch) + var headerBuf [gomemcached.HDR_LEN]byte + var pkt gomemcached.MCRequest + var event *UprEvent + + mc := feed.conn.Hijack() + uprStats := &feed.stats + +loop: + for { + select { + case <-feed.closer: + logging.Infof("Feed has been closed. Exiting.") + break loop + default: + bytes, err := pkt.Receive(mc, headerBuf[:]) + if err != nil { + logging.Errorf("Error in receive %s", err.Error()) + feed.Error = err + // send all the stream close messages to the client + feed.doStreamClose(ch) + break loop + } else { + event = nil + res := &gomemcached.MCResponse{ + Opcode: pkt.Opcode, + Cas: pkt.Cas, + Opaque: pkt.Opaque, + Status: gomemcached.Status(pkt.VBucket), + Extras: pkt.Extras, + Key: pkt.Key, + Body: pkt.Body, + } + + vb := vbOpaque(pkt.Opaque) + appOpaque := appOpaque(pkt.Opaque) + uprStats.TotalBytes = uint64(bytes) + + feed.muVbstreams.RLock() + stream := feed.vbstreams[vb] + feed.muVbstreams.RUnlock() + + switch pkt.Opcode { + case gomemcached.UPR_STREAMREQ: + event, err = feed.negotiator.handleStreamRequest(feed, headerBuf, &pkt, bytes, res) + if err != nil { + logging.Infof(err.Error()) + break loop + } + case gomemcached.UPR_MUTATION, + gomemcached.UPR_DELETION, + gomemcached.UPR_EXPIRATION: + if stream == nil { + logging.Infof("Stream not found for vb %d: %#v", vb, pkt) + break loop + } + event = makeUprEvent(pkt, stream, bytes) + uprStats.TotalMutation++ + + case gomemcached.UPR_STREAMEND: + if stream == nil { + logging.Infof("Stream not found for vb %d: %#v", vb, pkt) + break loop + } + //stream has ended + event = makeUprEvent(pkt, stream, bytes) + logging.Infof("Stream Ended for vb %d", vb) + + feed.negotiator.deleteStreamFromMap(vb, appOpaque) + feed.cleanUpVbStream(vb) + + case gomemcached.UPR_SNAPSHOT: + if stream == nil { + logging.Infof("Stream not found for vb %d: %#v", vb, pkt) + break loop + } + // snapshot marker + event = makeUprEvent(pkt, stream, bytes) + uprStats.TotalSnapShot++ + + case gomemcached.UPR_FLUSH: + if stream == nil { + logging.Infof("Stream not found for vb %d: %#v", vb, pkt) + break loop + } + // special processing for flush ? + event = makeUprEvent(pkt, stream, bytes) + + case gomemcached.UPR_CLOSESTREAM: + if stream == nil { + logging.Infof("Stream not found for vb %d: %#v", vb, pkt) + break loop + } + event = makeUprEvent(pkt, stream, bytes) + event.Opcode = gomemcached.UPR_STREAMEND // opcode re-write !! + logging.Infof("Stream Closed for vb %d StreamEnd simulated", vb) + + feed.negotiator.deleteStreamFromMap(vb, appOpaque) + feed.cleanUpVbStream(vb) + + case gomemcached.UPR_ADDSTREAM: + logging.Infof("Opcode %v not implemented", pkt.Opcode) + + case gomemcached.UPR_CONTROL, gomemcached.UPR_BUFFERACK: + if res.Status != gomemcached.SUCCESS { + logging.Infof("Opcode %v received status %d", pkt.Opcode.String(), res.Status) + } + + case gomemcached.UPR_NOOP: + // send a NOOP back + noop := &gomemcached.MCResponse{ + Opcode: gomemcached.UPR_NOOP, + Opaque: pkt.Opaque, + } + + if err := feed.conn.TransmitResponse(noop); err != nil { + logging.Warnf("failed to transmit command %s. Error %s", noop.Opcode.String(), err.Error()) + } + default: + logging.Infof("Recived an unknown response for vbucket %d", vb) + } + } + + if event != nil { + select { + case ch <- event: + case <-feed.closer: + logging.Infof("Feed has been closed. Skip sending events. 
Exiting.") + break loop + } + + feed.muVbstreams.RLock() + l := len(feed.vbstreams) + feed.muVbstreams.RUnlock() + + if event.Opcode == gomemcached.UPR_CLOSESTREAM && l == 0 { + logging.Infof("No more streams") + } + } + + if !feed.ackByClient { + // if client does not ack, do the ack check now + feed.sendBufferAckIfNeeded(event) + } + } + } + + // make sure that feed is closed before we signal transmitCl and exit runFeed + feed.Close() + + close(feed.transmitCl) + logging.Infof("runFeed exiting") +} + +// Client, after completing processing of an UprEvent, need to call this API to notify UprFeed, +// so that UprFeed can update its ack bytes stats and send ack to DCP if needed +// Client needs to set ackByClient flag to true in NewUprFeedWithConfig() call as a prerequisite for this call to work +// This API is not thread safe. Caller should NOT have more than one go rountine calling this API +func (feed *UprFeed) ClientAck(event *UprEvent) error { + if !feed.ackByClient { + return errors.New("Upr feed does not have ackByclient flag set") + } + feed.sendBufferAckIfNeeded(event) + return nil +} + +// increment ack bytes if the event needs to be acked to DCP +// send buffer ack if enough ack bytes have been accumulated +func (feed *UprFeed) sendBufferAckIfNeeded(event *UprEvent) { + if event == nil || event.AckSize == 0 { + // this indicates that there is no need to ack to DCP + return + } + + totalBytes := feed.toAckBytes + event.AckSize + if totalBytes > feed.maxAckBytes { + feed.toAckBytes = 0 + feed.sendBufferAck(totalBytes) + } else { + feed.toAckBytes = totalBytes + } +} + +// send buffer ack to dcp +func (feed *UprFeed) sendBufferAck(sendSize uint32) { + bufferAck := &gomemcached.MCRequest{ + Opcode: gomemcached.UPR_BUFFERACK, + } + bufferAck.Extras = make([]byte, 4) + binary.BigEndian.PutUint32(bufferAck.Extras[:4], uint32(sendSize)) + feed.writeToTransmitCh(bufferAck) + feed.stats.TotalBufferAckSent++ +} + +func (feed *UprFeed) GetUprStats() *UprStats { + return &feed.stats +} + +func composeOpaque(vbno, opaqueMSB uint16) uint32 { + return (uint32(opaqueMSB) << 16) | uint32(vbno) +} + +func getUprOpenCtrlOpaque() uint32 { + return atomic.AddUint32(&opaqueOpenCtrlWell, 1) +} + +func appOpaque(opq32 uint32) uint16 { + return uint16((opq32 & 0xFFFF0000) >> 16) +} + +func vbOpaque(opq32 uint32) uint16 { + return uint16(opq32 & 0xFFFF) +} + +// Close this UprFeed. +func (feed *UprFeed) Close() { + feed.muClosed.Lock() + defer feed.muClosed.Unlock() + if !feed.closed { + close(feed.closer) + feed.closed = true + feed.negotiator.initialize() + } +} + +// check if the UprFeed has been closed +func (feed *UprFeed) Closed() bool { + feed.muClosed.RLock() + defer feed.muClosed.RUnlock() + return feed.closed +} diff --git a/vendor/github.com/couchbase/gomemcached/mc_constants.go b/vendor/github.com/couchbase/gomemcached/mc_constants.go new file mode 100644 index 00000000..1d5027d1 --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/mc_constants.go @@ -0,0 +1,335 @@ +// Package gomemcached is binary protocol packet formats and constants. +package gomemcached + +import ( + "fmt" +) + +const ( + REQ_MAGIC = 0x80 + RES_MAGIC = 0x81 +) + +// CommandCode for memcached packets. 
+type CommandCode uint8 + +const ( + GET = CommandCode(0x00) + SET = CommandCode(0x01) + ADD = CommandCode(0x02) + REPLACE = CommandCode(0x03) + DELETE = CommandCode(0x04) + INCREMENT = CommandCode(0x05) + DECREMENT = CommandCode(0x06) + QUIT = CommandCode(0x07) + FLUSH = CommandCode(0x08) + GETQ = CommandCode(0x09) + NOOP = CommandCode(0x0a) + VERSION = CommandCode(0x0b) + GETK = CommandCode(0x0c) + GETKQ = CommandCode(0x0d) + APPEND = CommandCode(0x0e) + PREPEND = CommandCode(0x0f) + STAT = CommandCode(0x10) + SETQ = CommandCode(0x11) + ADDQ = CommandCode(0x12) + REPLACEQ = CommandCode(0x13) + DELETEQ = CommandCode(0x14) + INCREMENTQ = CommandCode(0x15) + DECREMENTQ = CommandCode(0x16) + QUITQ = CommandCode(0x17) + FLUSHQ = CommandCode(0x18) + APPENDQ = CommandCode(0x19) + AUDIT = CommandCode(0x27) + PREPENDQ = CommandCode(0x1a) + GAT = CommandCode(0x1d) + HELLO = CommandCode(0x1f) + RGET = CommandCode(0x30) + RSET = CommandCode(0x31) + RSETQ = CommandCode(0x32) + RAPPEND = CommandCode(0x33) + RAPPENDQ = CommandCode(0x34) + RPREPEND = CommandCode(0x35) + RPREPENDQ = CommandCode(0x36) + RDELETE = CommandCode(0x37) + RDELETEQ = CommandCode(0x38) + RINCR = CommandCode(0x39) + RINCRQ = CommandCode(0x3a) + RDECR = CommandCode(0x3b) + RDECRQ = CommandCode(0x3c) + + SASL_LIST_MECHS = CommandCode(0x20) + SASL_AUTH = CommandCode(0x21) + SASL_STEP = CommandCode(0x22) + + SET_VBUCKET = CommandCode(0x3d) + + TAP_CONNECT = CommandCode(0x40) // Client-sent request to initiate Tap feed + TAP_MUTATION = CommandCode(0x41) // Notification of a SET/ADD/REPLACE/etc. on the server + TAP_DELETE = CommandCode(0x42) // Notification of a DELETE on the server + TAP_FLUSH = CommandCode(0x43) // Replicates a flush_all command + TAP_OPAQUE = CommandCode(0x44) // Opaque control data from the engine + TAP_VBUCKET_SET = CommandCode(0x45) // Sets state of vbucket in receiver (used in takeover) + TAP_CHECKPOINT_START = CommandCode(0x46) // Notifies start of new checkpoint + TAP_CHECKPOINT_END = CommandCode(0x47) // Notifies end of checkpoint + + UPR_OPEN = CommandCode(0x50) // Open a UPR connection with a name + UPR_ADDSTREAM = CommandCode(0x51) // Sent by ebucketMigrator to UPR Consumer + UPR_CLOSESTREAM = CommandCode(0x52) // Sent by eBucketMigrator to UPR Consumer + UPR_FAILOVERLOG = CommandCode(0x54) // Request failover logs + UPR_STREAMREQ = CommandCode(0x53) // Stream request from consumer to producer + UPR_STREAMEND = CommandCode(0x55) // Sent by producer when it has no more messages to stream + UPR_SNAPSHOT = CommandCode(0x56) // Start of a new snapshot + UPR_MUTATION = CommandCode(0x57) // Key mutation + UPR_DELETION = CommandCode(0x58) // Key deletion + UPR_EXPIRATION = CommandCode(0x59) // Key expiration + UPR_FLUSH = CommandCode(0x5a) // Delete all the data for a vbucket + UPR_NOOP = CommandCode(0x5c) // UPR NOOP + UPR_BUFFERACK = CommandCode(0x5d) // UPR Buffer Acknowledgement + UPR_CONTROL = CommandCode(0x5e) // Set flow control params + + SELECT_BUCKET = CommandCode(0x89) // Select bucket + + OBSERVE_SEQNO = CommandCode(0x91) // Sequence Number based Observe + OBSERVE = CommandCode(0x92) + + GET_META = CommandCode(0xA0) // Get meta. returns with expiry, flags, cas etc + SUBDOC_GET = CommandCode(0xc5) // Get subdoc. Returns with xattrs + SUBDOC_MULTI_LOOKUP = CommandCode(0xd0) // Multi lookup. Doc xattrs and meta. 
+)
+
+// command codes that are counted toward DCP control buffer
+// when DCP clients receive DCP messages with these command codes, they need to provide acknowledgement
+var BufferedCommandCodeMap = map[CommandCode]bool{
+	SET_VBUCKET:    true,
+	UPR_STREAMEND:  true,
+	UPR_SNAPSHOT:   true,
+	UPR_MUTATION:   true,
+	UPR_DELETION:   true,
+	UPR_EXPIRATION: true}
+
+// Status field for memcached response.
+type Status uint16
+
+// Matches with protocol_binary.h as source of truth
+const (
+	SUCCESS         = Status(0x00)
+	KEY_ENOENT      = Status(0x01)
+	KEY_EEXISTS     = Status(0x02)
+	E2BIG           = Status(0x03)
+	EINVAL          = Status(0x04)
+	NOT_STORED      = Status(0x05)
+	DELTA_BADVAL    = Status(0x06)
+	NOT_MY_VBUCKET  = Status(0x07)
+	NO_BUCKET       = Status(0x08)
+	LOCKED          = Status(0x09)
+	AUTH_STALE      = Status(0x1f)
+	AUTH_ERROR      = Status(0x20)
+	AUTH_CONTINUE   = Status(0x21)
+	ERANGE          = Status(0x22)
+	ROLLBACK        = Status(0x23)
+	EACCESS         = Status(0x24)
+	NOT_INITIALIZED = Status(0x25)
+	UNKNOWN_COMMAND = Status(0x81)
+	ENOMEM          = Status(0x82)
+	NOT_SUPPORTED   = Status(0x83)
+	EINTERNAL       = Status(0x84)
+	EBUSY           = Status(0x85)
+	TMPFAIL         = Status(0x86)
+
+	// SUBDOC
+	SUBDOC_PATH_NOT_FOUND             = Status(0xc0)
+	SUBDOC_BAD_MULTI                  = Status(0xcc)
+	SUBDOC_MULTI_PATH_FAILURE_DELETED = Status(0xd3)
+)
+
+// for log redaction
+const (
+	UdTagBegin = "<ud>"
+	UdTagEnd   = "</ud>"
+)
+
+var isFatal = map[Status]bool{
+	DELTA_BADVAL:  true,
+	NO_BUCKET:     true,
+	AUTH_STALE:    true,
+	AUTH_ERROR:    true,
+	ERANGE:        true,
+	ROLLBACK:      true,
+	EACCESS:       true,
+	ENOMEM:        true,
+	NOT_SUPPORTED: true,
+}
+
+// the producer/consumer bit in dcp flags
+var DCP_PRODUCER uint32 = 0x01
+
+// the include XATTRS bit in dcp flags
+var DCP_OPEN_INCLUDE_XATTRS uint32 = 0x04
+
+// the include deletion time bit in dcp flags
+var DCP_OPEN_INCLUDE_DELETE_TIMES uint32 = 0x20
+
+// Datatype to Include XATTRS in SUBDOC GET
+var SUBDOC_FLAG_XATTR uint8 = 0x04
+
+// MCItem is an internal representation of an item.
+type MCItem struct {
+	Cas               uint64
+	Flags, Expiration uint32
+	Data              []byte
+}
+
+// Number of bytes in a binary protocol header.
+const HDR_LEN = 24
+
+// Mapping of CommandCode -> name of command (not exhaustive)
+var CommandNames map[CommandCode]string
+
+// StatusNames human readable names for memcached response.
+var StatusNames map[Status]string + +func init() { + CommandNames = make(map[CommandCode]string) + CommandNames[GET] = "GET" + CommandNames[SET] = "SET" + CommandNames[ADD] = "ADD" + CommandNames[REPLACE] = "REPLACE" + CommandNames[DELETE] = "DELETE" + CommandNames[INCREMENT] = "INCREMENT" + CommandNames[DECREMENT] = "DECREMENT" + CommandNames[QUIT] = "QUIT" + CommandNames[FLUSH] = "FLUSH" + CommandNames[GETQ] = "GETQ" + CommandNames[NOOP] = "NOOP" + CommandNames[VERSION] = "VERSION" + CommandNames[GETK] = "GETK" + CommandNames[GETKQ] = "GETKQ" + CommandNames[APPEND] = "APPEND" + CommandNames[PREPEND] = "PREPEND" + CommandNames[STAT] = "STAT" + CommandNames[SETQ] = "SETQ" + CommandNames[ADDQ] = "ADDQ" + CommandNames[REPLACEQ] = "REPLACEQ" + CommandNames[DELETEQ] = "DELETEQ" + CommandNames[INCREMENTQ] = "INCREMENTQ" + CommandNames[DECREMENTQ] = "DECREMENTQ" + CommandNames[QUITQ] = "QUITQ" + CommandNames[FLUSHQ] = "FLUSHQ" + CommandNames[APPENDQ] = "APPENDQ" + CommandNames[PREPENDQ] = "PREPENDQ" + CommandNames[RGET] = "RGET" + CommandNames[RSET] = "RSET" + CommandNames[RSETQ] = "RSETQ" + CommandNames[RAPPEND] = "RAPPEND" + CommandNames[RAPPENDQ] = "RAPPENDQ" + CommandNames[RPREPEND] = "RPREPEND" + CommandNames[RPREPENDQ] = "RPREPENDQ" + CommandNames[RDELETE] = "RDELETE" + CommandNames[RDELETEQ] = "RDELETEQ" + CommandNames[RINCR] = "RINCR" + CommandNames[RINCRQ] = "RINCRQ" + CommandNames[RDECR] = "RDECR" + CommandNames[RDECRQ] = "RDECRQ" + + CommandNames[SASL_LIST_MECHS] = "SASL_LIST_MECHS" + CommandNames[SASL_AUTH] = "SASL_AUTH" + CommandNames[SASL_STEP] = "SASL_STEP" + + CommandNames[TAP_CONNECT] = "TAP_CONNECT" + CommandNames[TAP_MUTATION] = "TAP_MUTATION" + CommandNames[TAP_DELETE] = "TAP_DELETE" + CommandNames[TAP_FLUSH] = "TAP_FLUSH" + CommandNames[TAP_OPAQUE] = "TAP_OPAQUE" + CommandNames[TAP_VBUCKET_SET] = "TAP_VBUCKET_SET" + CommandNames[TAP_CHECKPOINT_START] = "TAP_CHECKPOINT_START" + CommandNames[TAP_CHECKPOINT_END] = "TAP_CHECKPOINT_END" + + CommandNames[UPR_OPEN] = "UPR_OPEN" + CommandNames[UPR_ADDSTREAM] = "UPR_ADDSTREAM" + CommandNames[UPR_CLOSESTREAM] = "UPR_CLOSESTREAM" + CommandNames[UPR_FAILOVERLOG] = "UPR_FAILOVERLOG" + CommandNames[UPR_STREAMREQ] = "UPR_STREAMREQ" + CommandNames[UPR_STREAMEND] = "UPR_STREAMEND" + CommandNames[UPR_SNAPSHOT] = "UPR_SNAPSHOT" + CommandNames[UPR_MUTATION] = "UPR_MUTATION" + CommandNames[UPR_DELETION] = "UPR_DELETION" + CommandNames[UPR_EXPIRATION] = "UPR_EXPIRATION" + CommandNames[UPR_FLUSH] = "UPR_FLUSH" + CommandNames[UPR_NOOP] = "UPR_NOOP" + CommandNames[UPR_BUFFERACK] = "UPR_BUFFERACK" + CommandNames[UPR_CONTROL] = "UPR_CONTROL" + CommandNames[SUBDOC_GET] = "SUBDOC_GET" + CommandNames[SUBDOC_MULTI_LOOKUP] = "SUBDOC_MULTI_LOOKUP" + + StatusNames = make(map[Status]string) + StatusNames[SUCCESS] = "SUCCESS" + StatusNames[KEY_ENOENT] = "KEY_ENOENT" + StatusNames[KEY_EEXISTS] = "KEY_EEXISTS" + StatusNames[E2BIG] = "E2BIG" + StatusNames[EINVAL] = "EINVAL" + StatusNames[NOT_STORED] = "NOT_STORED" + StatusNames[DELTA_BADVAL] = "DELTA_BADVAL" + StatusNames[NOT_MY_VBUCKET] = "NOT_MY_VBUCKET" + StatusNames[NO_BUCKET] = "NO_BUCKET" + StatusNames[AUTH_STALE] = "AUTH_STALE" + StatusNames[AUTH_ERROR] = "AUTH_ERROR" + StatusNames[AUTH_CONTINUE] = "AUTH_CONTINUE" + StatusNames[ERANGE] = "ERANGE" + StatusNames[ROLLBACK] = "ROLLBACK" + StatusNames[EACCESS] = "EACCESS" + StatusNames[NOT_INITIALIZED] = "NOT_INITIALIZED" + StatusNames[UNKNOWN_COMMAND] = "UNKNOWN_COMMAND" + StatusNames[ENOMEM] = "ENOMEM" + StatusNames[NOT_SUPPORTED] = "NOT_SUPPORTED" + 
+	StatusNames[EINTERNAL] = "EINTERNAL"
+	StatusNames[EBUSY] = "EBUSY"
+	StatusNames[TMPFAIL] = "TMPFAIL"
+	StatusNames[SUBDOC_PATH_NOT_FOUND] = "SUBDOC_PATH_NOT_FOUND"
+	StatusNames[SUBDOC_BAD_MULTI] = "SUBDOC_BAD_MULTI"
+
+}
+
+// String an op code.
+func (o CommandCode) String() (rv string) {
+	rv = CommandNames[o]
+	if rv == "" {
+		rv = fmt.Sprintf("0x%02x", int(o))
+	}
+	return rv
+}
+
+// String a status code.
+func (s Status) String() (rv string) {
+	rv = StatusNames[s]
+	if rv == "" {
+		rv = fmt.Sprintf("0x%02x", int(s))
+	}
+	return rv
+}
+
+// IsQuiet will return true if a command is a "quiet" command.
+func (o CommandCode) IsQuiet() bool {
+	switch o {
+	case GETQ,
+		GETKQ,
+		SETQ,
+		ADDQ,
+		REPLACEQ,
+		DELETEQ,
+		INCREMENTQ,
+		DECREMENTQ,
+		QUITQ,
+		FLUSHQ,
+		APPENDQ,
+		PREPENDQ,
+		RSETQ,
+		RAPPENDQ,
+		RPREPENDQ,
+		RDELETEQ,
+		RINCRQ,
+		RDECRQ:
+		return true
+	}
+	return false
+}
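// Quick sketch of the lookup behavior above: the String methods fall back to a
// hex rendering for unmapped values, which is why CommandNames/StatusNames need
// not be exhaustive.
//
//	gomemcached.UPR_MUTATION.String()      // "UPR_MUTATION"
//	gomemcached.CommandCode(0x99).String() // "0x99" (no name registered)
//	gomemcached.GETQ.IsQuiet()             // true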
diff --git a/vendor/github.com/couchbase/gomemcached/mc_req.go b/vendor/github.com/couchbase/gomemcached/mc_req.go
new file mode 100644
index 00000000..3ff67ab9
--- /dev/null
+++ b/vendor/github.com/couchbase/gomemcached/mc_req.go
@@ -0,0 +1,197 @@
+package gomemcached
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+)
+
+// The maximum reasonable body length to expect.
+// Anything larger than this will result in an error.
+// The current limit, 20MB, is the size limit supported by ep-engine.
+var MaxBodyLen = int(20 * 1024 * 1024)
+
+// MCRequest is memcached Request
+type MCRequest struct {
+	// The command being issued
+	Opcode CommandCode
+	// The CAS (if applicable, or 0)
+	Cas uint64
+	// An opaque value to be returned with this request
+	Opaque uint32
+	// The vbucket to which this command belongs
+	VBucket uint16
+	// Command extras, key, and body
+	Extras, Key, Body, ExtMeta []byte
+	// Datatype identifier
+	DataType uint8
+}
+
+// Size gives the number of bytes this request requires.
+func (req *MCRequest) Size() int {
+	return HDR_LEN + len(req.Extras) + len(req.Key) + len(req.Body) + len(req.ExtMeta)
+}
+
+// A debugging string representation of this request
+func (req MCRequest) String() string {
+	return fmt.Sprintf("{MCRequest opcode=%s, bodylen=%d, key='%s'}",
+		req.Opcode, len(req.Body), req.Key)
+}
+
+func (req *MCRequest) fillHeaderBytes(data []byte) int {
+
+	pos := 0
+	data[pos] = REQ_MAGIC
+	pos++
+	data[pos] = byte(req.Opcode)
+	pos++
+	binary.BigEndian.PutUint16(data[pos:pos+2],
+		uint16(len(req.Key)))
+	pos += 2
+
+	// 4
+	data[pos] = byte(len(req.Extras))
+	pos++
+	// Data type
+	if req.DataType != 0 {
+		data[pos] = byte(req.DataType)
+	}
+	pos++
+	binary.BigEndian.PutUint16(data[pos:pos+2], req.VBucket)
+	pos += 2
+
+	// 8
+	binary.BigEndian.PutUint32(data[pos:pos+4],
+		uint32(len(req.Body)+len(req.Key)+len(req.Extras)+len(req.ExtMeta)))
+	pos += 4
+
+	// 12
+	binary.BigEndian.PutUint32(data[pos:pos+4], req.Opaque)
+	pos += 4
+
+	// 16
+	if req.Cas != 0 {
+		binary.BigEndian.PutUint64(data[pos:pos+8], req.Cas)
+	}
+	pos += 8
+
+	if len(req.Extras) > 0 {
+		copy(data[pos:pos+len(req.Extras)], req.Extras)
+		pos += len(req.Extras)
+	}
+
+	if len(req.Key) > 0 {
+		copy(data[pos:pos+len(req.Key)], req.Key)
+		pos += len(req.Key)
+	}
+
+	return pos
+}
+
+// HeaderBytes will return the wire representation of the request header
+// (with the extras and key).
+func (req *MCRequest) HeaderBytes() []byte {
+	data := make([]byte, HDR_LEN+len(req.Extras)+len(req.Key))
+
+	req.fillHeaderBytes(data)
+
+	return data
+}
+
+// Bytes will return the wire representation of this request.
+func (req *MCRequest) Bytes() []byte {
+	data := make([]byte, req.Size())
+
+	pos := req.fillHeaderBytes(data)
+
+	if len(req.Body) > 0 {
+		copy(data[pos:pos+len(req.Body)], req.Body)
+	}
+
+	if len(req.ExtMeta) > 0 {
+		copy(data[pos+len(req.Body):pos+len(req.Body)+len(req.ExtMeta)], req.ExtMeta)
+	}
+
+	return data
+}
+
+// Transmit will send this request message across a writer.
+func (req *MCRequest) Transmit(w io.Writer) (n int, err error) {
+	if len(req.Body) < 128 {
+		n, err = w.Write(req.Bytes())
+	} else {
+		n, err = w.Write(req.HeaderBytes())
+		if err == nil {
+			m := 0
+			m, err = w.Write(req.Body)
+			n += m
+		}
+	}
+	return
+}
+
+// Receive will fill this MCRequest with the data from a reader.
+func (req *MCRequest) Receive(r io.Reader, hdrBytes []byte) (int, error) {
+	if len(hdrBytes) < HDR_LEN {
+		hdrBytes = []byte{
+			0, 0, 0, 0, 0, 0, 0, 0,
+			0, 0, 0, 0, 0, 0, 0, 0,
+			0, 0, 0, 0, 0, 0, 0, 0}
+	}
+	n, err := io.ReadFull(r, hdrBytes)
+	if err != nil {
+		return n, err
+	}
+
+	if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC {
+		return n, fmt.Errorf("bad magic: 0x%02x", hdrBytes[0])
+	}
+
+	klen := int(binary.BigEndian.Uint16(hdrBytes[2:]))
+	elen := int(hdrBytes[4])
+	// Data type at 5
+	req.DataType = uint8(hdrBytes[5])
+
+	req.Opcode = CommandCode(hdrBytes[1])
+	// Vbucket at 6:7
+	req.VBucket = binary.BigEndian.Uint16(hdrBytes[6:])
+	totalBodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:]))
+
+	req.Opaque = binary.BigEndian.Uint32(hdrBytes[12:])
+	req.Cas = binary.BigEndian.Uint64(hdrBytes[16:])
+
+	if totalBodyLen > 0 {
+		buf := make([]byte, totalBodyLen)
+		m, err := io.ReadFull(r, buf)
+		n += m
+		if err == nil {
+			if req.Opcode >= TAP_MUTATION &&
+				req.Opcode <= TAP_CHECKPOINT_END &&
+				len(buf) > 1 {
+				// In these commands there is "engine private"
+				// data at the end of the extras. The first 2
+				// bytes of extra data give its length.
+				elen += int(binary.BigEndian.Uint16(buf))
+			}
+
+			req.Extras = buf[0:elen]
+			req.Key = buf[elen : klen+elen]
+
+			// get the length of extended metadata
+			extMetaLen := 0
+			if elen > 29 {
+				extMetaLen = int(binary.BigEndian.Uint16(req.Extras[28:30]))
+			}
+
+			bodyLen := totalBodyLen - klen - elen - extMetaLen
+			if bodyLen > MaxBodyLen {
+				return n, fmt.Errorf("%d is too big (max %d)",
+					bodyLen, MaxBodyLen)
+			}
+
+			req.Body = buf[klen+elen : klen+elen+bodyLen]
+			req.ExtMeta = buf[klen+elen+bodyLen:]
+		}
+	}
+	return n, err
+}
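// Sketch: an MCRequest round-trips through its wire form. Receive accepts any
// io.Reader, and passing nil for hdrBytes makes it allocate a scratch header;
// this assumes the bytes package is imported by the caller.
//
//	req := &gomemcached.MCRequest{Opcode: gomemcached.SET, VBucket: 9,
//		Key: []byte("k"), Body: []byte("v")}
//	var decoded gomemcached.MCRequest
//	_, err := decoded.Receive(bytes.NewReader(req.Bytes()), nil)
//	// err == nil; decoded.Key and decoded.Body match the originals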
diff --git a/vendor/github.com/couchbase/gomemcached/mc_res.go b/vendor/github.com/couchbase/gomemcached/mc_res.go
new file mode 100644
index 00000000..2b4cfe13
--- /dev/null
+++ b/vendor/github.com/couchbase/gomemcached/mc_res.go
@@ -0,0 +1,267 @@
+package gomemcached
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"sync"
+)
+
+// MCResponse is memcached response
+type MCResponse struct {
+	// The command opcode of the command that sent the request
+	Opcode CommandCode
+	// The status of the response
+	Status Status
+	// The opaque sent in the request
+	Opaque uint32
+	// The CAS identifier (if applicable)
+	Cas uint64
+	// Extras, key, and body for this response
+	Extras, Key, Body []byte
+	// If true, this represents a fatal condition and we should hang up
+	Fatal bool
+	// Datatype identifier
+	DataType uint8
+}
+
+// A debugging string representation of this response
+func (res MCResponse) String() string {
+	return fmt.Sprintf("{MCResponse status=%v keylen=%d, extralen=%d, bodylen=%d}",
+		res.Status, len(res.Key), len(res.Extras), len(res.Body))
+}
+
+// Response as an error.
+func (res *MCResponse) Error() string {
+	return fmt.Sprintf("MCResponse status=%v, opcode=%v, opaque=%v, msg: %s",
+		res.Status, res.Opcode, res.Opaque, string(res.Body))
+}
+
+func errStatus(e error) Status {
+	status := Status(0xffff)
+	if res, ok := e.(*MCResponse); ok {
+		status = res.Status
+	}
+	return status
+}
+
+// IsNotFound is true if this error represents a "not found" response.
+func IsNotFound(e error) bool {
+	return errStatus(e) == KEY_ENOENT
+}
+
+// IsFatal is false if this error isn't believed to be fatal to a connection.
+func IsFatal(e error) bool {
+	if e == nil {
+		return false
+	}
+	_, ok := isFatal[errStatus(e)]
+	if ok {
+		return true
+	}
+	return false
+}
+
+// Size is number of bytes this response consumes on the wire.
+func (res *MCResponse) Size() int {
+	return HDR_LEN + len(res.Extras) + len(res.Key) + len(res.Body)
+}
+
+func (res *MCResponse) fillHeaderBytes(data []byte) int {
+	pos := 0
+	data[pos] = RES_MAGIC
+	pos++
+	data[pos] = byte(res.Opcode)
+	pos++
+	binary.BigEndian.PutUint16(data[pos:pos+2],
+		uint16(len(res.Key)))
+	pos += 2
+
+	// 4
+	data[pos] = byte(len(res.Extras))
+	pos++
+	// Data type
+	if res.DataType != 0 {
+		data[pos] = byte(res.DataType)
+	} else {
+		data[pos] = 0
+	}
+	pos++
+	binary.BigEndian.PutUint16(data[pos:pos+2], uint16(res.Status))
+	pos += 2
+
+	// 8
+	binary.BigEndian.PutUint32(data[pos:pos+4],
+		uint32(len(res.Body)+len(res.Key)+len(res.Extras)))
+	pos += 4
+
+	// 12
+	binary.BigEndian.PutUint32(data[pos:pos+4], res.Opaque)
+	pos += 4
+
+	// 16
+	binary.BigEndian.PutUint64(data[pos:pos+8], res.Cas)
+	pos += 8
+
+	if len(res.Extras) > 0 {
+		copy(data[pos:pos+len(res.Extras)], res.Extras)
+		pos += len(res.Extras)
+	}
+
+	if len(res.Key) > 0 {
+		copy(data[pos:pos+len(res.Key)], res.Key)
+		pos += len(res.Key)
+	}
+
+	return pos
+}
+
+// HeaderBytes will get just the header bytes for this response.
+func (res *MCResponse) HeaderBytes() []byte {
+	data := make([]byte, HDR_LEN+len(res.Extras)+len(res.Key))
+
+	res.fillHeaderBytes(data)
+
+	return data
+}
+
+// Bytes will return the actual bytes transmitted for this response.
+func (res *MCResponse) Bytes() []byte {
+	data := make([]byte, res.Size())
+
+	pos := res.fillHeaderBytes(data)
+
+	copy(data[pos:pos+len(res.Body)], res.Body)
+
+	return data
+}
+
+// Transmit will send this response message across a writer.
+func (res *MCResponse) Transmit(w io.Writer) (n int, err error) {
+	if len(res.Body) < 128 {
+		n, err = w.Write(res.Bytes())
+	} else {
+		n, err = w.Write(res.HeaderBytes())
+		if err == nil {
+			m := 0
+			m, err = w.Write(res.Body)
+			n += m
+		}
+	}
+	return
+}
+
+// Receive will fill this MCResponse with the data from this reader.
+func (res *MCResponse) Receive(r io.Reader, hdrBytes []byte) (n int, err error) { + if len(hdrBytes) < HDR_LEN { + hdrBytes = []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0} + } + n, err = io.ReadFull(r, hdrBytes) + if err != nil { + return n, err + } + + if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC { + return n, fmt.Errorf("bad magic: 0x%02x", hdrBytes[0]) + } + + klen := int(binary.BigEndian.Uint16(hdrBytes[2:4])) + elen := int(hdrBytes[4]) + + res.Opcode = CommandCode(hdrBytes[1]) + res.DataType = uint8(hdrBytes[5]) + res.Status = Status(binary.BigEndian.Uint16(hdrBytes[6:8])) + res.Opaque = binary.BigEndian.Uint32(hdrBytes[12:16]) + res.Cas = binary.BigEndian.Uint64(hdrBytes[16:24]) + + bodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:12])) - (klen + elen) + + //defer function to debug the panic seen with MB-15557 + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf(`Panic in Receive. Response %v \n + key len %v extra len %v bodylen %v`, res, klen, elen, bodyLen) + } + }() + + buf := make([]byte, klen+elen+bodyLen) + m, err := io.ReadFull(r, buf) + if err == nil { + res.Extras = buf[0:elen] + res.Key = buf[elen : klen+elen] + res.Body = buf[klen+elen:] + } + + return n + m, err +} + +type MCResponsePool struct { + pool *sync.Pool +} + +func NewMCResponsePool() *MCResponsePool { + rv := &MCResponsePool{ + pool: &sync.Pool{ + New: func() interface{} { + return &MCResponse{} + }, + }, + } + + return rv +} + +func (this *MCResponsePool) Get() *MCResponse { + return this.pool.Get().(*MCResponse) +} + +func (this *MCResponsePool) Put(r *MCResponse) { + if r == nil { + return + } + + r.Extras = nil + r.Key = nil + r.Body = nil + r.Fatal = false + + this.pool.Put(r) +} + +type StringMCResponsePool struct { + pool *sync.Pool + size int +} + +func NewStringMCResponsePool(size int) *StringMCResponsePool { + rv := &StringMCResponsePool{ + pool: &sync.Pool{ + New: func() interface{} { + return make(map[string]*MCResponse, size) + }, + }, + size: size, + } + + return rv +} + +func (this *StringMCResponsePool) Get() map[string]*MCResponse { + return this.pool.Get().(map[string]*MCResponse) +} + +func (this *StringMCResponsePool) Put(m map[string]*MCResponse) { + if m == nil || len(m) > 2*this.size { + return + } + + for k := range m { + m[k] = nil + delete(m, k) + } + + this.pool.Put(m) +} diff --git a/vendor/github.com/couchbase/gomemcached/tap.go b/vendor/github.com/couchbase/gomemcached/tap.go new file mode 100644 index 00000000..e4862328 --- /dev/null +++ b/vendor/github.com/couchbase/gomemcached/tap.go @@ -0,0 +1,168 @@ +package gomemcached + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "strings" +) + +type TapConnectFlag uint32 + +// Tap connect option flags +const ( + BACKFILL = TapConnectFlag(0x01) + DUMP = TapConnectFlag(0x02) + LIST_VBUCKETS = TapConnectFlag(0x04) + TAKEOVER_VBUCKETS = TapConnectFlag(0x08) + SUPPORT_ACK = TapConnectFlag(0x10) + REQUEST_KEYS_ONLY = TapConnectFlag(0x20) + CHECKPOINT = TapConnectFlag(0x40) + REGISTERED_CLIENT = TapConnectFlag(0x80) + FIX_FLAG_BYTEORDER = TapConnectFlag(0x100) +) + +// Tap opaque event subtypes +const ( + TAP_OPAQUE_ENABLE_AUTO_NACK = 0 + TAP_OPAQUE_INITIAL_VBUCKET_STREAM = 1 + TAP_OPAQUE_ENABLE_CHECKPOINT_SYNC = 2 + TAP_OPAQUE_CLOSE_TAP_STREAM = 7 + TAP_OPAQUE_CLOSE_BACKFILL = 8 +) + +// Tap item flags +const ( + TAP_ACK = 1 + TAP_NO_VALUE = 2 + TAP_FLAG_NETWORK_BYTE_ORDER = 4 +) + +// TapConnectFlagNames for TapConnectFlag +var TapConnectFlagNames = 
map[TapConnectFlag]string{
+	BACKFILL:           "BACKFILL",
+	DUMP:               "DUMP",
+	LIST_VBUCKETS:      "LIST_VBUCKETS",
+	TAKEOVER_VBUCKETS:  "TAKEOVER_VBUCKETS",
+	SUPPORT_ACK:        "SUPPORT_ACK",
+	REQUEST_KEYS_ONLY:  "REQUEST_KEYS_ONLY",
+	CHECKPOINT:         "CHECKPOINT",
+	REGISTERED_CLIENT:  "REGISTERED_CLIENT",
+	FIX_FLAG_BYTEORDER: "FIX_FLAG_BYTEORDER",
+}
+
+// TapItemParser is a function to parse a single tap extra.
+type TapItemParser func(io.Reader) (interface{}, error)
+
+// TapParseUint64 is a function to parse a single tap uint64.
+func TapParseUint64(r io.Reader) (interface{}, error) {
+	var rv uint64
+	err := binary.Read(r, binary.BigEndian, &rv)
+	return rv, err
+}
+
+// TapParseUint16 is a function to parse a single tap uint16.
+func TapParseUint16(r io.Reader) (interface{}, error) {
+	var rv uint16
+	err := binary.Read(r, binary.BigEndian, &rv)
+	return rv, err
+}
+
+// TapParseBool is a function to parse a single tap boolean.
+func TapParseBool(r io.Reader) (interface{}, error) {
+	return true, nil
+}
+
+// TapParseVBList parses a list of vBucket numbers as []uint16.
+func TapParseVBList(r io.Reader) (interface{}, error) {
+	num, err := TapParseUint16(r)
+	if err != nil {
+		return nil, err
+	}
+	n := int(num.(uint16))
+
+	rv := make([]uint16, n)
+	for i := 0; i < n; i++ {
+		x, err := TapParseUint16(r)
+		if err != nil {
+			return nil, err
+		}
+		rv[i] = x.(uint16)
+	}
+
+	return rv, err
+}
+
+// TapFlagParsers parser functions for TAP fields.
+var TapFlagParsers = map[TapConnectFlag]TapItemParser{
+	BACKFILL:      TapParseUint64,
+	LIST_VBUCKETS: TapParseVBList,
+}
+
+// SplitFlags will split the ORed flags into the individual bit flags.
+func (f TapConnectFlag) SplitFlags() []TapConnectFlag {
+	rv := []TapConnectFlag{}
+	for i := uint32(1); f != 0; i = i << 1 {
+		if uint32(f)&i == i {
+			rv = append(rv, TapConnectFlag(i))
+		}
+		f = TapConnectFlag(uint32(f) & (^i))
+	}
+	return rv
+}
+
+func (f TapConnectFlag) String() string {
+	parts := []string{}
+	for _, x := range f.SplitFlags() {
+		p := TapConnectFlagNames[x]
+		if p == "" {
+			p = fmt.Sprintf("0x%x", int(x))
+		}
+		parts = append(parts, p)
+	}
+	return strings.Join(parts, "|")
+}
+
+type TapConnect struct {
+	Flags         map[TapConnectFlag]interface{}
+	RemainingBody []byte
+	Name          string
+}
+
+// ParseTapCommands parses the tap request into the interesting bits we may
+// need to do something with.
+func (req *MCRequest) ParseTapCommands() (TapConnect, error) {
+	rv := TapConnect{
+		Flags: map[TapConnectFlag]interface{}{},
+		Name:  string(req.Key),
+	}
+
+	if len(req.Extras) < 4 {
+		return rv, fmt.Errorf("not enough extra bytes: %x", req.Extras)
+	}
+
+	flags := TapConnectFlag(binary.BigEndian.Uint32(req.Extras))
+
+	r := bytes.NewReader(req.Body)
+
+	for _, f := range flags.SplitFlags() {
+		fun := TapFlagParsers[f]
+		if fun == nil {
+			fun = TapParseBool
+		}
+
+		val, err := fun(r)
+		if err != nil {
+			return rv, err
+		}
+
+		rv.Flags[f] = val
+	}
+
+	var err error
+	rv.RemainingBody, err = ioutil.ReadAll(r)
+
+	return rv, err
+}
diff --git a/vendor/github.com/couchbase/goutils/LICENSE.md b/vendor/github.com/couchbase/goutils/LICENSE.md
new file mode 100644
index 00000000..a572e246
--- /dev/null
+++ b/vendor/github.com/couchbase/goutils/LICENSE.md
@@ -0,0 +1,47 @@
+COUCHBASE INC.
COMMUNITY EDITION LICENSE AGREEMENT + +IMPORTANT-READ CAREFULLY: BY CLICKING THE "I ACCEPT" BOX OR INSTALLING, +DOWNLOADING OR OTHERWISE USING THIS SOFTWARE AND ANY ASSOCIATED +DOCUMENTATION, YOU, ON BEHALF OF YOURSELF OR AS AN AUTHORIZED +REPRESENTATIVE ON BEHALF OF AN ENTITY ("LICENSEE") AGREE TO ALL THE +TERMS OF THIS COMMUNITY EDITION LICENSE AGREEMENT (THE "AGREEMENT") +REGARDING YOUR USE OF THE SOFTWARE. YOU REPRESENT AND WARRANT THAT YOU +HAVE FULL LEGAL AUTHORITY TO BIND THE LICENSEE TO THIS AGREEMENT. IF YOU +DO NOT AGREE WITH ALL OF THESE TERMS, DO NOT SELECT THE "I ACCEPT" BOX +AND DO NOT INSTALL, DOWNLOAD OR OTHERWISE USE THE SOFTWARE. THE +EFFECTIVE DATE OF THIS AGREEMENT IS THE DATE ON WHICH YOU CLICK "I +ACCEPT" OR OTHERWISE INSTALL, DOWNLOAD OR USE THE SOFTWARE. + +1. License Grant. Couchbase Inc. hereby grants Licensee, free of charge, +the non-exclusive right to use, copy, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to +whom the Software is furnished to do so, subject to Licensee including +the following copyright notice in all copies or substantial portions of +the Software: + +Couchbase (r) http://www.Couchbase.com Copyright 2016 Couchbase, Inc. + +As used in this Agreement, "Software" means the object code version of +the applicable elastic data management server software provided by +Couchbase Inc. + +2. Restrictions. Licensee will not reverse engineer, disassemble, or +decompile the Software (except to the extent such restrictions are +prohibited by law). + +3. Support. Couchbase, Inc. will provide Licensee with access to, and +use of, the Couchbase, Inc. support forum available at the following +URL: http://www.couchbase.org/forums/. Couchbase, Inc. may, at its +discretion, modify, suspend or terminate support at any time upon notice +to Licensee. + +4. Warranty Disclaimer and Limitation of Liability. THE SOFTWARE IS +PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +COUCHBASE INC. OR THE AUTHORS OR COPYRIGHT HOLDERS IN THE SOFTWARE BE +LIABLE FOR ANY CLAIM, DAMAGES (IINCLUDING, WITHOUT LIMITATION, DIRECT, +INDIRECT OR CONSEQUENTIAL DAMAGES) OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/couchbase/goutils/logging/logger.go b/vendor/github.com/couchbase/goutils/logging/logger.go new file mode 100644 index 00000000..b9948f9b --- /dev/null +++ b/vendor/github.com/couchbase/goutils/logging/logger.go @@ -0,0 +1,481 @@ +// Copyright (c) 2016 Couchbase, Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +// except in compliance with the License. You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software distributed under the +// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +// either express or implied. See the License for the specific language governing permissions +// and limitations under the License. 
+ +package logging + +import ( + "os" + "runtime" + "strings" + "sync" +) + +type Level int + +const ( + NONE = Level(iota) // Disable all logging + FATAL // System is in severe error state and has to abort + SEVERE // System is in severe error state and cannot recover reliably + ERROR // System is in error state but can recover and continue reliably + WARN // System approaching error state, or is in a correct but undesirable state + INFO // System-level events and status, in correct states + REQUEST // Request-level events, with request-specific rlevel + TRACE // Trace detailed system execution, e.g. function entry / exit + DEBUG // Debug +) + +type LogEntryFormatter int + +const ( + TEXTFORMATTER = LogEntryFormatter(iota) + JSONFORMATTER + KVFORMATTER +) + +func (level Level) String() string { + return _LEVEL_NAMES[level] +} + +var _LEVEL_NAMES = []string{ + DEBUG: "DEBUG", + TRACE: "TRACE", + REQUEST: "REQUEST", + INFO: "INFO", + WARN: "WARN", + ERROR: "ERROR", + SEVERE: "SEVERE", + FATAL: "FATAL", + NONE: "NONE", +} + +var _LEVEL_MAP = map[string]Level{ + "debug": DEBUG, + "trace": TRACE, + "request": REQUEST, + "info": INFO, + "warn": WARN, + "error": ERROR, + "severe": SEVERE, + "fatal": FATAL, + "none": NONE, +} + +func ParseLevel(name string) (level Level, ok bool) { + level, ok = _LEVEL_MAP[strings.ToLower(name)] + return +} + +/* + +Pair supports logging of key-value pairs. Keys beginning with _ are +reserved for the logger, e.g. _time, _level, _msg, and _rlevel. The +Pair APIs are designed to avoid heap allocation and garbage +collection. + +*/ +type Pairs []Pair +type Pair struct { + Name string + Value interface{} +} + +/* + +Map allows key-value pairs to be specified using map literals or data +structures. For example: + +Errorm(msg, Map{...}) + +Map incurs heap allocation and garbage collection, so the Pair APIs +should be preferred. + +*/ +type Map map[string]interface{} + +// Logger provides a common interface for logging libraries +type Logger interface { + /* + These APIs write all the given pairs in addition to standard logger keys. + */ + Logp(level Level, msg string, kv ...Pair) + + Debugp(msg string, kv ...Pair) + + Tracep(msg string, kv ...Pair) + + Requestp(rlevel Level, msg string, kv ...Pair) + + Infop(msg string, kv ...Pair) + + Warnp(msg string, kv ...Pair) + + Errorp(msg string, kv ...Pair) + + Severep(msg string, kv ...Pair) + + Fatalp(msg string, kv ...Pair) + + /* + These APIs write the fields in the given kv Map in addition to standard logger keys. + */ + Logm(level Level, msg string, kv Map) + + Debugm(msg string, kv Map) + + Tracem(msg string, kv Map) + + Requestm(rlevel Level, msg string, kv Map) + + Infom(msg string, kv Map) + + Warnm(msg string, kv Map) + + Errorm(msg string, kv Map) + + Severem(msg string, kv Map) + + Fatalm(msg string, kv Map) + + /* + + These APIs only write _msg, _time, _level, and other logger keys. If + the msg contains other fields, use the Pair or Map APIs instead. 
+ + */ + Logf(level Level, fmt string, args ...interface{}) + + Debugf(fmt string, args ...interface{}) + + Tracef(fmt string, args ...interface{}) + + Requestf(rlevel Level, fmt string, args ...interface{}) + + Infof(fmt string, args ...interface{}) + + Warnf(fmt string, args ...interface{}) + + Errorf(fmt string, args ...interface{}) + + Severef(fmt string, args ...interface{}) + + Fatalf(fmt string, args ...interface{}) + + /* + These APIs control the logging level + */ + + SetLevel(Level) // Set the logging level + + Level() Level // Get the current logging level +} + +var logger Logger = nil +var curLevel Level = DEBUG // initially set to never skip + +var loggerMutex sync.RWMutex + +// All the methods below first acquire the mutex (mostly in exclusive mode) +// and only then check if logging at the current level is enabled. +// This introduces a fair bottleneck for those log entries that should be +// skipped (the majority, at INFO or below levels) +// We try to predict here if we should lock the mutex at all by caching +// the current log level: while dynamically changing logger, there might +// be the odd entry skipped as the new level is cached. +// Since we seem to never change the logger, this is not an issue. +func skipLogging(level Level) bool { + if logger == nil { + return true + } + return level > curLevel +} + +func SetLogger(newLogger Logger) { + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger = newLogger + if logger == nil { + curLevel = NONE + } else { + curLevel = newLogger.Level() + } +} + +func Logp(level Level, msg string, kv ...Pair) { + if skipLogging(level) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Logp(level, msg, kv...) +} + +func Debugp(msg string, kv ...Pair) { + if skipLogging(DEBUG) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Debugp(msg, kv...) +} + +func Tracep(msg string, kv ...Pair) { + if skipLogging(TRACE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Tracep(msg, kv...) +} + +func Requestp(rlevel Level, msg string, kv ...Pair) { + if skipLogging(REQUEST) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Requestp(rlevel, msg, kv...) +} + +func Infop(msg string, kv ...Pair) { + if skipLogging(INFO) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Infop(msg, kv...) +} + +func Warnp(msg string, kv ...Pair) { + if skipLogging(WARN) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Warnp(msg, kv...) +} + +func Errorp(msg string, kv ...Pair) { + if skipLogging(ERROR) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Errorp(msg, kv...) +} + +func Severep(msg string, kv ...Pair) { + if skipLogging(SEVERE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Severep(msg, kv...) +} + +func Fatalp(msg string, kv ...Pair) { + if skipLogging(FATAL) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Fatalp(msg, kv...) 
+} + +func Logm(level Level, msg string, kv Map) { + if skipLogging(level) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Logm(level, msg, kv) +} + +func Debugm(msg string, kv Map) { + if skipLogging(DEBUG) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Debugm(msg, kv) +} + +func Tracem(msg string, kv Map) { + if skipLogging(TRACE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Tracem(msg, kv) +} + +func Requestm(rlevel Level, msg string, kv Map) { + if skipLogging(REQUEST) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Requestm(rlevel, msg, kv) +} + +func Infom(msg string, kv Map) { + if skipLogging(INFO) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Infom(msg, kv) +} + +func Warnm(msg string, kv Map) { + if skipLogging(WARN) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Warnm(msg, kv) +} + +func Errorm(msg string, kv Map) { + if skipLogging(ERROR) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Errorm(msg, kv) +} + +func Severem(msg string, kv Map) { + if skipLogging(SEVERE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Severem(msg, kv) +} + +func Fatalm(msg string, kv Map) { + if skipLogging(FATAL) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Fatalm(msg, kv) +} + +func Logf(level Level, fmt string, args ...interface{}) { + if skipLogging(level) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Logf(level, fmt, args...) +} + +func Debugf(fmt string, args ...interface{}) { + if skipLogging(DEBUG) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Debugf(fmt, args...) +} + +func Tracef(fmt string, args ...interface{}) { + if skipLogging(TRACE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Tracef(fmt, args...) +} + +func Requestf(rlevel Level, fmt string, args ...interface{}) { + if skipLogging(REQUEST) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Requestf(rlevel, fmt, args...) +} + +func Infof(fmt string, args ...interface{}) { + if skipLogging(INFO) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Infof(fmt, args...) +} + +func Warnf(fmt string, args ...interface{}) { + if skipLogging(WARN) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Warnf(fmt, args...) +} + +func Errorf(fmt string, args ...interface{}) { + if skipLogging(ERROR) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Errorf(fmt, args...) +} + +func Severef(fmt string, args ...interface{}) { + if skipLogging(SEVERE) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Severef(fmt, args...) +} + +func Fatalf(fmt string, args ...interface{}) { + if skipLogging(FATAL) { + return + } + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Fatalf(fmt, args...) +} + +func SetLevel(level Level) { + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.SetLevel(level) + curLevel = level +} + +func LogLevel() Level { + loggerMutex.RLock() + defer loggerMutex.RUnlock() + return logger.Level() +} + +func Stackf(level Level, fmt string, args ...interface{}) { + if skipLogging(level) { + return + } + buf := make([]byte, 1<<16) + n := runtime.Stack(buf, false) + s := string(buf[0:n]) + loggerMutex.Lock() + defer loggerMutex.Unlock() + logger.Logf(level, fmt, args...) 
+ logger.Logf(level, s) +} + +func init() { + logger = NewLogger(os.Stderr, INFO, TEXTFORMATTER) + SetLogger(logger) +} diff --git a/vendor/github.com/couchbase/goutils/logging/logger_golog.go b/vendor/github.com/couchbase/goutils/logging/logger_golog.go new file mode 100644 index 00000000..eec432a5 --- /dev/null +++ b/vendor/github.com/couchbase/goutils/logging/logger_golog.go @@ -0,0 +1,318 @@ +// Copyright (c) 2016 Couchbase, Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +// except in compliance with the License. You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software distributed under the +// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +// either express or implied. See the License for the specific language governing permissions +// and limitations under the License. + +package logging + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "time" +) + +type goLogger struct { + logger *log.Logger + level Level + entryFormatter formatter +} + +const ( + _LEVEL = "_level" + _MSG = "_msg" + _TIME = "_time" + _RLEVEL = "_rlevel" +) + +func NewLogger(out io.Writer, lvl Level, fmtLogging LogEntryFormatter) *goLogger { + logger := &goLogger{ + logger: log.New(out, "", 0), + level: lvl, + } + if fmtLogging == JSONFORMATTER { + logger.entryFormatter = &jsonFormatter{} + } else if fmtLogging == KVFORMATTER { + logger.entryFormatter = &keyvalueFormatter{} + } else { + logger.entryFormatter = &textFormatter{} + } + return logger +} + +func (gl *goLogger) Logp(level Level, msg string, kv ...Pair) { + if gl.logger == nil { + return + } + if level <= gl.level { + e := newLogEntry(msg, level) + copyPairs(e, kv) + gl.log(e) + } +} + +func (gl *goLogger) Debugp(msg string, kv ...Pair) { + gl.Logp(DEBUG, msg, kv...) +} + +func (gl *goLogger) Tracep(msg string, kv ...Pair) { + gl.Logp(TRACE, msg, kv...) +} + +func (gl *goLogger) Requestp(rlevel Level, msg string, kv ...Pair) { + if gl.logger == nil { + return + } + if REQUEST <= gl.level { + e := newLogEntry(msg, REQUEST) + e.Rlevel = rlevel + copyPairs(e, kv) + gl.log(e) + } +} + +func (gl *goLogger) Infop(msg string, kv ...Pair) { + gl.Logp(INFO, msg, kv...) +} + +func (gl *goLogger) Warnp(msg string, kv ...Pair) { + gl.Logp(WARN, msg, kv...) +} + +func (gl *goLogger) Errorp(msg string, kv ...Pair) { + gl.Logp(ERROR, msg, kv...) +} + +func (gl *goLogger) Severep(msg string, kv ...Pair) { + gl.Logp(SEVERE, msg, kv...) +} + +func (gl *goLogger) Fatalp(msg string, kv ...Pair) { + gl.Logp(FATAL, msg, kv...) 
+} + +func (gl *goLogger) Logm(level Level, msg string, kv Map) { + if gl.logger == nil { + return + } + if level <= gl.level { + e := newLogEntry(msg, level) + e.Data = kv + gl.log(e) + } +} + +func (gl *goLogger) Debugm(msg string, kv Map) { + gl.Logm(DEBUG, msg, kv) +} + +func (gl *goLogger) Tracem(msg string, kv Map) { + gl.Logm(TRACE, msg, kv) +} + +func (gl *goLogger) Requestm(rlevel Level, msg string, kv Map) { + if gl.logger == nil { + return + } + if REQUEST <= gl.level { + e := newLogEntry(msg, REQUEST) + e.Rlevel = rlevel + e.Data = kv + gl.log(e) + } +} + +func (gl *goLogger) Infom(msg string, kv Map) { + gl.Logm(INFO, msg, kv) +} + +func (gl *goLogger) Warnm(msg string, kv Map) { + gl.Logm(WARN, msg, kv) +} + +func (gl *goLogger) Errorm(msg string, kv Map) { + gl.Logm(ERROR, msg, kv) +} + +func (gl *goLogger) Severem(msg string, kv Map) { + gl.Logm(SEVERE, msg, kv) +} + +func (gl *goLogger) Fatalm(msg string, kv Map) { + gl.Logm(FATAL, msg, kv) +} + +func (gl *goLogger) Logf(level Level, format string, args ...interface{}) { + if gl.logger == nil { + return + } + if level <= gl.level { + e := newLogEntry(fmt.Sprintf(format, args...), level) + gl.log(e) + } +} + +func (gl *goLogger) Debugf(format string, args ...interface{}) { + gl.Logf(DEBUG, format, args...) +} + +func (gl *goLogger) Tracef(format string, args ...interface{}) { + gl.Logf(TRACE, format, args...) +} + +func (gl *goLogger) Requestf(rlevel Level, format string, args ...interface{}) { + if gl.logger == nil { + return + } + if REQUEST <= gl.level { + e := newLogEntry(fmt.Sprintf(format, args...), REQUEST) + e.Rlevel = rlevel + gl.log(e) + } +} + +func (gl *goLogger) Infof(format string, args ...interface{}) { + gl.Logf(INFO, format, args...) +} + +func (gl *goLogger) Warnf(format string, args ...interface{}) { + gl.Logf(WARN, format, args...) +} + +func (gl *goLogger) Errorf(format string, args ...interface{}) { + gl.Logf(ERROR, format, args...) +} + +func (gl *goLogger) Severef(format string, args ...interface{}) { + gl.Logf(SEVERE, format, args...) +} + +func (gl *goLogger) Fatalf(format string, args ...interface{}) { + gl.Logf(FATAL, format, args...) +} + +func (gl *goLogger) Level() Level { + return gl.level +} + +func (gl *goLogger) SetLevel(level Level) { + gl.level = level +} + +func (gl *goLogger) log(newEntry *logEntry) { + s := gl.entryFormatter.format(newEntry) + gl.logger.Print(s) +} + +type logEntry struct { + Time string + Level Level + Rlevel Level + Message string + Data Map +} + +func newLogEntry(msg string, level Level) *logEntry { + return &logEntry{ + Time: time.Now().Format("2006-01-02T15:04:05.000-07:00"), // time.RFC3339 with milliseconds + Level: level, + Rlevel: NONE, + Message: msg, + } +} + +func copyPairs(newEntry *logEntry, pairs []Pair) { + newEntry.Data = make(Map, len(pairs)) + for _, p := range pairs { + newEntry.Data[p.Name] = p.Value + } +} + +type formatter interface { + format(*logEntry) string +} + +type textFormatter struct { +} + +// ex. 
2016-02-10T09:15:25.498-08:00 [INFO] This is a message from test in text format + +func (*textFormatter) format(newEntry *logEntry) string { + b := &bytes.Buffer{} + appendValue(b, newEntry.Time) + if newEntry.Rlevel != NONE { + fmt.Fprintf(b, "[%s,%s] ", newEntry.Level.String(), newEntry.Rlevel.String()) + } else { + fmt.Fprintf(b, "[%s] ", newEntry.Level.String()) + } + appendValue(b, newEntry.Message) + for key, value := range newEntry.Data { + appendKeyValue(b, key, value) + } + b.WriteByte('\n') + s := bytes.NewBuffer(b.Bytes()) + return s.String() +} + +func appendValue(b *bytes.Buffer, value interface{}) { + if _, ok := value.(string); ok { + fmt.Fprintf(b, "%s ", value) + } else { + fmt.Fprintf(b, "%v ", value) + } +} + +type keyvalueFormatter struct { +} + +// ex. _time=2016-02-10T09:15:25.498-08:00 _level=INFO _msg=This is a message from test in key-value format + +func (*keyvalueFormatter) format(newEntry *logEntry) string { + b := &bytes.Buffer{} + appendKeyValue(b, _TIME, newEntry.Time) + appendKeyValue(b, _LEVEL, newEntry.Level.String()) + if newEntry.Rlevel != NONE { + appendKeyValue(b, _RLEVEL, newEntry.Rlevel.String()) + } + appendKeyValue(b, _MSG, newEntry.Message) + for key, value := range newEntry.Data { + appendKeyValue(b, key, value) + } + b.WriteByte('\n') + s := bytes.NewBuffer(b.Bytes()) + return s.String() +} + +func appendKeyValue(b *bytes.Buffer, key, value interface{}) { + if _, ok := value.(string); ok { + fmt.Fprintf(b, "%v=%s ", key, value) + } else { + fmt.Fprintf(b, "%v=%v ", key, value) + } +} + +type jsonFormatter struct { +} + +// ex. {"_level":"INFO","_msg":"This is a message from test in json format","_time":"2016-02-10T09:12:59.518-08:00"} + +func (*jsonFormatter) format(newEntry *logEntry) string { + if newEntry.Data == nil { + newEntry.Data = make(Map, 5) + } + newEntry.Data[_TIME] = newEntry.Time + newEntry.Data[_LEVEL] = newEntry.Level.String() + if newEntry.Rlevel != NONE { + newEntry.Data[_RLEVEL] = newEntry.Rlevel.String() + } + newEntry.Data[_MSG] = newEntry.Message + serialized, _ := json.Marshal(newEntry.Data) + s := bytes.NewBuffer(append(serialized, '\n')) + return s.String() +} diff --git a/vendor/github.com/couchbase/goutils/scramsha/scramsha.go b/vendor/github.com/couchbase/goutils/scramsha/scramsha.go new file mode 100644 index 00000000..b234bfc8 --- /dev/null +++ b/vendor/github.com/couchbase/goutils/scramsha/scramsha.go @@ -0,0 +1,207 @@ +// @author Couchbase +// @copyright 2018 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
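[Editor's note: before the SCRAM-SHA code begins, a usage sketch for the logging package vendored just above. This is not part of the diff. It assumes Pair is the exported struct with Name/Value fields and Map the string-keyed map that copyPairs and the formatters imply; the three call families mirror the interface shown earlier.]

```go
package main

import (
	"os"

	"github.com/couchbase/goutils/logging"
)

func main() {
	// Swap the default logger (stderr, INFO, text format) for a DEBUG-level
	// JSON logger; SetLogger also refreshes the level cached by skipLogging,
	// so Debug* calls are no longer skipped.
	logging.SetLogger(logging.NewLogger(os.Stderr, logging.DEBUG, logging.JSONFORMATTER))

	// The three call families exposed by the package-level wrappers:
	logging.Infof("listening on port %d", 8080)                                  // printf-style
	logging.Debugp("cache lookup", logging.Pair{Name: "key", Value: "user:42"})  // explicit pairs
	logging.Errorm("request failed", logging.Map{"status": 500, "path": "/api"}) // map payload
}
```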
+ +// Package scramsha provides implementation of client side SCRAM-SHA +// according to https://tools.ietf.org/html/rfc5802 +package scramsha + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "fmt" + "github.com/pkg/errors" + "golang.org/x/crypto/pbkdf2" + "hash" + "strconv" + "strings" +) + +func hmacHash(message []byte, secret []byte, hashFunc func() hash.Hash) []byte { + h := hmac.New(hashFunc, secret) + h.Write(message) + return h.Sum(nil) +} + +func shaHash(message []byte, hashFunc func() hash.Hash) []byte { + h := hashFunc() + h.Write(message) + return h.Sum(nil) +} + +func generateClientNonce(size int) (string, error) { + randomBytes := make([]byte, size) + _, err := rand.Read(randomBytes) + if err != nil { + return "", errors.Wrap(err, "Unable to generate nonce") + } + return base64.StdEncoding.EncodeToString(randomBytes), nil +} + +// ScramSha provides context for SCRAM-SHA handling +type ScramSha struct { + hashSize int + hashFunc func() hash.Hash + clientNonce string + serverNonce string + salt []byte + i int + saltedPassword []byte + authMessage string +} + +var knownMethods = []string{"SCRAM-SHA512", "SCRAM-SHA256", "SCRAM-SHA1"} + +// BestMethod returns SCRAM-SHA method we consider the best out of suggested +// by server +func BestMethod(methods string) (string, error) { + for _, m := range knownMethods { + if strings.Index(methods, m) != -1 { + return m, nil + } + } + return "", errors.Errorf( + "None of the server suggested methods [%s] are supported", + methods) +} + +// NewScramSha creates context for SCRAM-SHA handling +func NewScramSha(method string) (*ScramSha, error) { + s := &ScramSha{} + + if method == knownMethods[0] { + s.hashFunc = sha512.New + s.hashSize = 64 + } else if method == knownMethods[1] { + s.hashFunc = sha256.New + s.hashSize = 32 + } else if method == knownMethods[2] { + s.hashFunc = sha1.New + s.hashSize = 20 + } else { + return nil, errors.Errorf("Unsupported method %s", method) + } + return s, nil +} + +// GetStartRequest builds start SCRAM-SHA request to be sent to server +func (s *ScramSha) GetStartRequest(user string) (string, error) { + var err error + s.clientNonce, err = generateClientNonce(24) + if err != nil { + return "", errors.Wrapf(err, "Unable to generate SCRAM-SHA "+ + "start request for user %s", user) + } + + message := fmt.Sprintf("n,,n=%s,r=%s", user, s.clientNonce) + s.authMessage = message[3:] + return message, nil +} + +// HandleStartResponse handles server response on start SCRAM-SHA request +func (s *ScramSha) HandleStartResponse(response string) error { + parts := strings.Split(response, ",") + if len(parts) != 3 { + return errors.Errorf("expected 3 fields in first SCRAM-SHA-1 "+ + "server message %s", response) + } + if !strings.HasPrefix(parts[0], "r=") || len(parts[0]) < 3 { + return errors.Errorf("Server sent an invalid nonce %s", + parts[0]) + } + if !strings.HasPrefix(parts[1], "s=") || len(parts[1]) < 3 { + return errors.Errorf("Server sent an invalid salt %s", parts[1]) + } + if !strings.HasPrefix(parts[2], "i=") || len(parts[2]) < 3 { + return errors.Errorf("Server sent an invalid iteration count %s", + parts[2]) + } + + s.serverNonce = parts[0][2:] + encodedSalt := parts[1][2:] + var err error + s.i, err = strconv.Atoi(parts[2][2:]) + if err != nil { + return errors.Errorf("Iteration count %s must be integer.", + parts[2][2:]) + } + + if s.i < 1 { + return errors.New("Iteration count should be positive") + } + + if 
!strings.HasPrefix(s.serverNonce, s.clientNonce) { + return errors.Errorf("Server nonce %s doesn't contain client"+ + " nonce %s", s.serverNonce, s.clientNonce) + } + + s.salt, err = base64.StdEncoding.DecodeString(encodedSalt) + if err != nil { + return errors.Wrapf(err, "Unable to decode salt %s", + encodedSalt) + } + + s.authMessage = s.authMessage + "," + response + return nil +} + +// GetFinalRequest builds final SCRAM-SHA request to be sent to server +func (s *ScramSha) GetFinalRequest(pass string) string { + clientFinalMessageBare := "c=biws,r=" + s.serverNonce + s.authMessage = s.authMessage + "," + clientFinalMessageBare + + s.saltedPassword = pbkdf2.Key([]byte(pass), s.salt, s.i, + s.hashSize, s.hashFunc) + + clientKey := hmacHash([]byte("Client Key"), s.saltedPassword, s.hashFunc) + storedKey := shaHash(clientKey, s.hashFunc) + clientSignature := hmacHash([]byte(s.authMessage), storedKey, s.hashFunc) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + return clientFinalMessageBare + ",p=" + + base64.StdEncoding.EncodeToString(clientProof) +} + +// HandleFinalResponse handles server's response on final SCRAM-SHA request +func (s *ScramSha) HandleFinalResponse(response string) error { + if strings.Contains(response, ",") || + !strings.HasPrefix(response, "v=") { + return errors.Errorf("Server sent an invalid final message %s", + response) + } + + decodedMessage, err := base64.StdEncoding.DecodeString(response[2:]) + if err != nil { + return errors.Wrapf(err, "Unable to decode server message %s", + response[2:]) + } + serverKey := hmacHash([]byte("Server Key"), s.saltedPassword, + s.hashFunc) + serverSignature := hmacHash([]byte(s.authMessage), serverKey, + s.hashFunc) + if string(decodedMessage) != string(serverSignature) { + return errors.Errorf("Server proof %s doesn't match "+ + "the expected: %s", + string(decodedMessage), string(serverSignature)) + } + return nil +} diff --git a/vendor/github.com/couchbase/goutils/scramsha/scramsha_http.go b/vendor/github.com/couchbase/goutils/scramsha/scramsha_http.go new file mode 100644 index 00000000..19f32b31 --- /dev/null +++ b/vendor/github.com/couchbase/goutils/scramsha/scramsha_http.go @@ -0,0 +1,252 @@ +// @author Couchbase +// @copyright 2018 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
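[Editor's note: that completes the core handshake state machine in scramsha.go; the next file binds it to HTTP. As a sketch of how the context is meant to be driven, the call sequence below follows the exported API above. The exchange function is a placeholder for whatever transport carries the SASL messages; it is not part of the package.]

```go
package main

import (
	"fmt"

	"github.com/couchbase/goutils/scramsha"
)

// exchange stands in for the transport that carries the SASL messages
// between client and server; a real implementation would send msg and
// return the server's reply.
func exchange(msg string) string {
	return ""
}

func authenticate(user, pass, serverMethods string) error {
	// Pick the strongest method the server offers (SHA512 > SHA256 > SHA1).
	method, err := scramsha.BestMethod(serverMethods)
	if err != nil {
		return err
	}
	s, err := scramsha.NewScramSha(method)
	if err != nil {
		return err
	}
	// Round trip 1: client-first message, server nonce/salt/iterations.
	start, err := s.GetStartRequest(user)
	if err != nil {
		return err
	}
	if err := s.HandleStartResponse(exchange(start)); err != nil {
		return err
	}
	// Round trip 2: client proof out, server signature verified on the way back.
	return s.HandleFinalResponse(exchange(s.GetFinalRequest(pass)))
}

func main() {
	fmt.Println(authenticate("beego", "secret", "SCRAM-SHA512,SCRAM-SHA256"))
}
```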
+ +// Package scramsha provides implementation of client side SCRAM-SHA +// via Http according to https://tools.ietf.org/html/rfc7804 +package scramsha + +import ( + "encoding/base64" + "github.com/pkg/errors" + "io" + "io/ioutil" + "net/http" + "strings" +) + +// consts used to parse scramsha response from target +const ( + WWWAuthenticate = "WWW-Authenticate" + AuthenticationInfo = "Authentication-Info" + Authorization = "Authorization" + DataPrefix = "data=" + SidPrefix = "sid=" +) + +// Request provides implementation of http request that can be retried +type Request struct { + body io.ReadSeeker + + // Embed an HTTP request directly. This makes a *Request act exactly + // like an *http.Request so that all meta methods are supported. + *http.Request +} + +type lenReader interface { + Len() int +} + +// NewRequest creates http request that can be retried +func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) { + // Wrap the body in a noop ReadCloser if non-nil. This prevents the + // reader from being closed by the HTTP client. + var rcBody io.ReadCloser + if body != nil { + rcBody = ioutil.NopCloser(body) + } + + // Make the request with the noop-closer for the body. + httpReq, err := http.NewRequest(method, url, rcBody) + if err != nil { + return nil, err + } + + // Check if we can set the Content-Length automatically. + if lr, ok := body.(lenReader); ok { + httpReq.ContentLength = int64(lr.Len()) + } + + return &Request{body, httpReq}, nil +} + +func encode(str string) string { + return base64.StdEncoding.EncodeToString([]byte(str)) +} + +func decode(str string) (string, error) { + bytes, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return "", errors.Errorf("Cannot base64 decode %s", + str) + } + return string(bytes), err +} + +func trimPrefix(s, prefix string) (string, error) { + l := len(s) + trimmed := strings.TrimPrefix(s, prefix) + if l == len(trimmed) { + return trimmed, errors.Errorf("Prefix %s not found in %s", + prefix, s) + } + return trimmed, nil +} + +func drainBody(resp *http.Response) { + defer resp.Body.Close() + io.Copy(ioutil.Discard, resp.Body) +} + +// DoScramSha performs SCRAM-SHA handshake via Http +func DoScramSha(req *Request, + username string, + password string, + client *http.Client) (*http.Response, error) { + + method := "SCRAM-SHA-512" + s, err := NewScramSha("SCRAM-SHA512") + if err != nil { + return nil, errors.Wrap(err, + "Unable to initialize SCRAM-SHA handler") + } + + message, err := s.GetStartRequest(username) + if err != nil { + return nil, err + } + + encodedMessage := method + " " + DataPrefix + encode(message) + + req.Header.Set(Authorization, encodedMessage) + + res, err := client.Do(req.Request) + if err != nil { + return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start"+ + "request") + } + + if res.StatusCode != http.StatusUnauthorized { + return res, nil + } + + authHeader := res.Header.Get(WWWAuthenticate) + if authHeader == "" { + drainBody(res) + return nil, errors.Errorf("Header %s is not populated in "+ + "SCRAM-SHA start response", WWWAuthenticate) + } + + authHeader, err = trimPrefix(authHeader, method+" ") + if err != nil { + if strings.HasPrefix(authHeader, "Basic ") { + // user not found + return res, nil + } + drainBody(res) + return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ + "start response %s", authHeader) + } + + drainBody(res) + + sid, response, err := parseSidAndData(authHeader) + if err != nil { + return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ + 
"start response %s", authHeader) + } + + err = s.HandleStartResponse(response) + if err != nil { + return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+ + "response %s", response) + } + + message = s.GetFinalRequest(password) + encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix + + encode(message) + + req.Header.Set(Authorization, encodedMessage) + + // rewind request body so it can be resent again + if req.body != nil { + if _, err = req.body.Seek(0, 0); err != nil { + return nil, errors.Errorf("Failed to seek body: %v", + err) + } + } + + res, err = client.Do(req.Request) + if err != nil { + return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final"+ + "request") + } + + if res.StatusCode == http.StatusUnauthorized { + // TODO retrieve and return error + return res, nil + } + + if res.StatusCode >= http.StatusInternalServerError { + // in this case we cannot expect server to set headers properly + return res, nil + } + + authHeader = res.Header.Get(AuthenticationInfo) + if authHeader == "" { + drainBody(res) + return nil, errors.Errorf("Header %s is not populated in "+ + "SCRAM-SHA final response", AuthenticationInfo) + } + + finalSid, response, err := parseSidAndData(authHeader) + if err != nil { + drainBody(res) + return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ + "final response %s", authHeader) + } + + if finalSid != sid { + drainBody(res) + return nil, errors.Errorf("Sid %s returned by server "+ + "doesn't match the original sid %s", finalSid, sid) + } + + err = s.HandleFinalResponse(response) + if err != nil { + drainBody(res) + return nil, errors.Wrapf(err, + "Error handling SCRAM-SHA final server response %s", + response) + } + return res, nil +} + +func parseSidAndData(authHeader string) (string, string, error) { + sidIndex := strings.Index(authHeader, SidPrefix) + if sidIndex < 0 { + return "", "", errors.Errorf("Cannot find %s in %s", + SidPrefix, authHeader) + } + + sidEndIndex := strings.Index(authHeader, ",") + if sidEndIndex < 0 { + return "", "", errors.Errorf("Cannot find ',' in %s", + authHeader) + } + + sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex] + + dataIndex := strings.Index(authHeader, DataPrefix) + if dataIndex < 0 { + return "", "", errors.Errorf("Cannot find %s in %s", + DataPrefix, authHeader) + } + + data, err := decode(authHeader[dataIndex+len(DataPrefix):]) + if err != nil { + return "", "", err + } + return sid, data, nil +} diff --git a/vendor/github.com/cupcake/rdb/.gitignore b/vendor/github.com/cupcake/rdb/.gitignore new file mode 100644 index 00000000..fcc1e666 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +# Project-specific files +diff diff --git a/vendor/github.com/cupcake/rdb/.travis.yml b/vendor/github.com/cupcake/rdb/.travis.yml new file mode 100644 index 00000000..49c6fb8f --- /dev/null +++ b/vendor/github.com/cupcake/rdb/.travis.yml @@ -0,0 +1,6 @@ +language: go +go: + - 1.1 + - tip +before_install: + - go get gopkg.in/check.v1 diff --git a/vendor/github.com/cupcake/rdb/LICENCE b/vendor/github.com/cupcake/rdb/LICENCE new file mode 100644 index 00000000..50257901 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/LICENCE @@ -0,0 +1,21 @@ +Copyright (c) 2012 Jonathan Rudenberg +Copyright 
(c) 2012 Sripathi Krishnan + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/cupcake/rdb/README.md b/vendor/github.com/cupcake/rdb/README.md new file mode 100644 index 00000000..5c19212c --- /dev/null +++ b/vendor/github.com/cupcake/rdb/README.md @@ -0,0 +1,17 @@ +# rdb [![Build Status](https://travis-ci.org/cupcake/rdb.png?branch=master)](https://travis-ci.org/cupcake/rdb) + +rdb is a Go package that implements parsing and encoding of the +[Redis](http://redis.io) [RDB file +format](https://github.com/sripathikrishnan/redis-rdb-tools/blob/master/docs/RDB_File_Format.textile). + +This package was heavily inspired by +[redis-rdb-tools](https://github.com/sripathikrishnan/redis-rdb-tools) by +[Sripathi Krishnan](https://github.com/sripathikrishnan). + +[**Documentation**](http://godoc.org/github.com/cupcake/rdb) + +## Installation + +``` +go get github.com/cupcake/rdb +``` diff --git a/vendor/github.com/cupcake/rdb/crc64/crc64.go b/vendor/github.com/cupcake/rdb/crc64/crc64.go new file mode 100644 index 00000000..54fed9c5 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/crc64/crc64.go @@ -0,0 +1,64 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package crc64 implements the Jones coefficients with an init value of 0. +package crc64 + +import "hash" + +// Redis uses the CRC64 variant with "Jones" coefficients and init value of 0. 
+// +// Specification of this CRC64 variant follows: +// Name: crc-64-jones +// Width: 64 bits +// Poly: 0xad93d23594c935a9 +// Reflected In: True +// Xor_In: 0xffffffffffffffff +// Reflected_Out: True +// Xor_Out: 0x0 + +var table = [256]uint64{0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2, 0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6, 0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75, 0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe, 0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08, 0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8, 0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e, 0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285, 0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306, 0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02, 0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02, 0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489, 0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f, 0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e, 0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8, 0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73, 0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271, 0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75, 0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6, 0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d, 0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b, 0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416, 0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0, 0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b, 0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8, 0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec, 0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee, 0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965, 0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693, 0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2, 0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14, 0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f, 0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f, 0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b, 0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18, 0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793, 0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865, 0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495, 0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63, 0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8, 0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b, 0x7c23a61ddbe6f812, 0x3373d23613f9d516, 0x49aba2fe23cc5c6f, 0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5, 0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e, 0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8, 0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9, 0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f, 0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4, 0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6, 0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2, 0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841, 0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca, 0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c, 0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce, 0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038, 
0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3, 0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30, 0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734, 0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936, 0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd, 0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b, 0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a, 0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc, 0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47, 0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628, 0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c, 0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf, 0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124, 0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2, 0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222, 0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4, 0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f, 0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc, 0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8, 0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8, 0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053, 0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5, 0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4, 0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322, 0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9, 0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab, 0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf, 0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c, 0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7, 0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51, 0x29b7d047efec8728} + +func crc64(crc uint64, b []byte) uint64 { + for _, v := range b { + crc = table[byte(crc)^v] ^ (crc >> 8) + } + return crc +} + +func Digest(b []byte) uint64 { + return crc64(0, b) +} + +type digest struct { + crc uint64 +} + +func New() hash.Hash64 { + return &digest{} +} + +func (h *digest) Write(p []byte) (int, error) { + h.crc = crc64(h.crc, p) + return len(p), nil +} + +// Encode in little endian +func (d *digest) Sum(in []byte) []byte { + s := d.Sum64() + in = append(in, byte(s)) + in = append(in, byte(s>>8)) + in = append(in, byte(s>>16)) + in = append(in, byte(s>>24)) + in = append(in, byte(s>>32)) + in = append(in, byte(s>>40)) + in = append(in, byte(s>>48)) + in = append(in, byte(s>>56)) + return in +} + +func (d *digest) Sum64() uint64 { return d.crc } +func (d *digest) BlockSize() int { return 1 } +func (d *digest) Size() int { return 8 } +func (d *digest) Reset() { d.crc = 0 } diff --git a/vendor/github.com/cupcake/rdb/decoder.go b/vendor/github.com/cupcake/rdb/decoder.go new file mode 100644 index 00000000..dd3993b5 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/decoder.go @@ -0,0 +1,824 @@ +// Package rdb implements parsing and encoding of the Redis RDB file format. +package rdb + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + + "github.com/cupcake/rdb/crc64" +) + +// A Decoder must be implemented to parse a RDB file. +type Decoder interface { + // StartRDB is called when parsing of a valid RDB file starts. + StartRDB() + // StartDatabase is called when database n starts. + // Once a database starts, another database will not start until EndDatabase is called. 
+ StartDatabase(n int) + // AUX field + Aux(key, value []byte) + // ResizeDB hint + ResizeDatabase(dbSize, expiresSize uint32) + // Set is called once for each string key. + Set(key, value []byte, expiry int64) + // StartHash is called at the beginning of a hash. + // Hset will be called exactly length times before EndHash. + StartHash(key []byte, length, expiry int64) + // Hset is called once for each field=value pair in a hash. + Hset(key, field, value []byte) + // EndHash is called when there are no more fields in a hash. + EndHash(key []byte) + // StartSet is called at the beginning of a set. + // Sadd will be called exactly cardinality times before EndSet. + StartSet(key []byte, cardinality, expiry int64) + // Sadd is called once for each member of a set. + Sadd(key, member []byte) + // EndSet is called when there are no more fields in a set. + EndSet(key []byte) + // StartList is called at the beginning of a list. + // Rpush will be called exactly length times before EndList. + // If length of the list is not known, then length is -1 + StartList(key []byte, length, expiry int64) + // Rpush is called once for each value in a list. + Rpush(key, value []byte) + // EndList is called when there are no more values in a list. + EndList(key []byte) + // StartZSet is called at the beginning of a sorted set. + // Zadd will be called exactly cardinality times before EndZSet. + StartZSet(key []byte, cardinality, expiry int64) + // Zadd is called once for each member of a sorted set. + Zadd(key []byte, score float64, member []byte) + // EndZSet is called when there are no more members in a sorted set. + EndZSet(key []byte) + // EndDatabase is called at the end of a database. + EndDatabase(n int) + // EndRDB is called when parsing of the RDB file is complete. + EndRDB() +} + +// Decode parses a RDB file from r and calls the decode hooks on d. +func Decode(r io.Reader, d Decoder) error { + decoder := &decode{d, make([]byte, 8), bufio.NewReader(r)} + return decoder.decode() +} + +// Decode a byte slice from the Redis DUMP command. The dump does not contain the +// database, key or expiry, so they must be included in the function call (but +// can be zero values). 
+func DecodeDump(dump []byte, db int, key []byte, expiry int64, d Decoder) error { + err := verifyDump(dump) + if err != nil { + return err + } + + decoder := &decode{d, make([]byte, 8), bytes.NewReader(dump[1:])} + decoder.event.StartRDB() + decoder.event.StartDatabase(db) + + err = decoder.readObject(key, ValueType(dump[0]), expiry) + + decoder.event.EndDatabase(db) + decoder.event.EndRDB() + return err +} + +type byteReader interface { + io.Reader + io.ByteReader +} + +type decode struct { + event Decoder + intBuf []byte + r byteReader +} + +type ValueType byte + +const ( + TypeString ValueType = 0 + TypeList ValueType = 1 + TypeSet ValueType = 2 + TypeZSet ValueType = 3 + TypeHash ValueType = 4 + + TypeHashZipmap ValueType = 9 + TypeListZiplist ValueType = 10 + TypeSetIntset ValueType = 11 + TypeZSetZiplist ValueType = 12 + TypeHashZiplist ValueType = 13 + TypeListQuicklist ValueType = 14 +) + +const ( + rdb6bitLen = 0 + rdb14bitLen = 1 + rdb32bitLen = 2 + rdbEncVal = 3 + + rdbFlagAux = 0xfa + rdbFlagResizeDB = 0xfb + rdbFlagExpiryMS = 0xfc + rdbFlagExpiry = 0xfd + rdbFlagSelectDB = 0xfe + rdbFlagEOF = 0xff + + rdbEncInt8 = 0 + rdbEncInt16 = 1 + rdbEncInt32 = 2 + rdbEncLZF = 3 + + rdbZiplist6bitlenString = 0 + rdbZiplist14bitlenString = 1 + rdbZiplist32bitlenString = 2 + + rdbZiplistInt16 = 0xc0 + rdbZiplistInt32 = 0xd0 + rdbZiplistInt64 = 0xe0 + rdbZiplistInt24 = 0xf0 + rdbZiplistInt8 = 0xfe + rdbZiplistInt4 = 15 +) + +func (d *decode) decode() error { + err := d.checkHeader() + if err != nil { + return err + } + + d.event.StartRDB() + + var db uint32 + var expiry int64 + firstDB := true + for { + objType, err := d.r.ReadByte() + if err != nil { + return err + } + switch objType { + case rdbFlagAux: + auxKey, err := d.readString() + if err != nil { + return err + } + auxVal, err := d.readString() + if err != nil { + return err + } + d.event.Aux(auxKey, auxVal) + case rdbFlagResizeDB: + dbSize, _, err := d.readLength() + if err != nil { + return err + } + expiresSize, _, err := d.readLength() + if err != nil { + return err + } + d.event.ResizeDatabase(dbSize, expiresSize) + case rdbFlagExpiryMS: + _, err := io.ReadFull(d.r, d.intBuf) + if err != nil { + return err + } + expiry = int64(binary.LittleEndian.Uint64(d.intBuf)) + case rdbFlagExpiry: + _, err := io.ReadFull(d.r, d.intBuf[:4]) + if err != nil { + return err + } + expiry = int64(binary.LittleEndian.Uint32(d.intBuf)) * 1000 + case rdbFlagSelectDB: + if !firstDB { + d.event.EndDatabase(int(db)) + } + db, _, err = d.readLength() + if err != nil { + return err + } + d.event.StartDatabase(int(db)) + case rdbFlagEOF: + d.event.EndDatabase(int(db)) + d.event.EndRDB() + return nil + default: + key, err := d.readString() + if err != nil { + return err + } + err = d.readObject(key, ValueType(objType), expiry) + if err != nil { + return err + } + expiry = 0 + } + } + + panic("not reached") +} + +func (d *decode) readObject(key []byte, typ ValueType, expiry int64) error { + switch typ { + case TypeString: + value, err := d.readString() + if err != nil { + return err + } + d.event.Set(key, value, expiry) + case TypeList: + length, _, err := d.readLength() + if err != nil { + return err + } + d.event.StartList(key, int64(length), expiry) + for i := uint32(0); i < length; i++ { + value, err := d.readString() + if err != nil { + return err + } + d.event.Rpush(key, value) + } + d.event.EndList(key) + case TypeListQuicklist: + length, _, err := d.readLength() + if err != nil { + return err + } + d.event.StartList(key, int64(-1), expiry) + for i 
:= uint32(0); i < length; i++ { + d.readZiplist(key, 0, false) + } + d.event.EndList(key) + case TypeSet: + cardinality, _, err := d.readLength() + if err != nil { + return err + } + d.event.StartSet(key, int64(cardinality), expiry) + for i := uint32(0); i < cardinality; i++ { + member, err := d.readString() + if err != nil { + return err + } + d.event.Sadd(key, member) + } + d.event.EndSet(key) + case TypeZSet: + cardinality, _, err := d.readLength() + if err != nil { + return err + } + d.event.StartZSet(key, int64(cardinality), expiry) + for i := uint32(0); i < cardinality; i++ { + member, err := d.readString() + if err != nil { + return err + } + score, err := d.readFloat64() + if err != nil { + return err + } + d.event.Zadd(key, score, member) + } + d.event.EndZSet(key) + case TypeHash: + length, _, err := d.readLength() + if err != nil { + return err + } + d.event.StartHash(key, int64(length), expiry) + for i := uint32(0); i < length; i++ { + field, err := d.readString() + if err != nil { + return err + } + value, err := d.readString() + if err != nil { + return err + } + d.event.Hset(key, field, value) + } + d.event.EndHash(key) + case TypeHashZipmap: + return d.readZipmap(key, expiry) + case TypeListZiplist: + return d.readZiplist(key, expiry, true) + case TypeSetIntset: + return d.readIntset(key, expiry) + case TypeZSetZiplist: + return d.readZiplistZset(key, expiry) + case TypeHashZiplist: + return d.readZiplistHash(key, expiry) + default: + return fmt.Errorf("rdb: unknown object type %d for key %s", typ, key) + } + return nil +} + +func (d *decode) readZipmap(key []byte, expiry int64) error { + var length int + zipmap, err := d.readString() + if err != nil { + return err + } + buf := newSliceBuffer(zipmap) + lenByte, err := buf.ReadByte() + if err != nil { + return err + } + if lenByte >= 254 { // we need to count the items manually + length, err = countZipmapItems(buf) + length /= 2 + if err != nil { + return err + } + } else { + length = int(lenByte) + } + d.event.StartHash(key, int64(length), expiry) + for i := 0; i < length; i++ { + field, err := readZipmapItem(buf, false) + if err != nil { + return err + } + value, err := readZipmapItem(buf, true) + if err != nil { + return err + } + d.event.Hset(key, field, value) + } + d.event.EndHash(key) + return nil +} + +func readZipmapItem(buf *sliceBuffer, readFree bool) ([]byte, error) { + length, free, err := readZipmapItemLength(buf, readFree) + if err != nil { + return nil, err + } + if length == -1 { + return nil, nil + } + value, err := buf.Slice(length) + if err != nil { + return nil, err + } + _, err = buf.Seek(int64(free), 1) + return value, err +} + +func countZipmapItems(buf *sliceBuffer) (int, error) { + n := 0 + for { + strLen, free, err := readZipmapItemLength(buf, n%2 != 0) + if err != nil { + return 0, err + } + if strLen == -1 { + break + } + _, err = buf.Seek(int64(strLen)+int64(free), 1) + if err != nil { + return 0, err + } + n++ + } + _, err := buf.Seek(0, 0) + return n, err +} + +func readZipmapItemLength(buf *sliceBuffer, readFree bool) (int, int, error) { + b, err := buf.ReadByte() + if err != nil { + return 0, 0, err + } + switch b { + case 253: + s, err := buf.Slice(5) + if err != nil { + return 0, 0, err + } + return int(binary.BigEndian.Uint32(s)), int(s[4]), nil + case 254: + return 0, 0, fmt.Errorf("rdb: invalid zipmap item length") + case 255: + return -1, 0, nil + } + var free byte + if readFree { + free, err = buf.ReadByte() + } + return int(b), int(free), err +} + +func (d *decode) readZiplist(key 
[]byte, expiry int64, addListEvents bool) error { + ziplist, err := d.readString() + if err != nil { + return err + } + buf := newSliceBuffer(ziplist) + length, err := readZiplistLength(buf) + if err != nil { + return err + } + if addListEvents { + d.event.StartList(key, length, expiry) + } + for i := int64(0); i < length; i++ { + entry, err := readZiplistEntry(buf) + if err != nil { + return err + } + d.event.Rpush(key, entry) + } + if addListEvents { + d.event.EndList(key) + } + return nil +} + +func (d *decode) readZiplistZset(key []byte, expiry int64) error { + ziplist, err := d.readString() + if err != nil { + return err + } + buf := newSliceBuffer(ziplist) + cardinality, err := readZiplistLength(buf) + if err != nil { + return err + } + cardinality /= 2 + d.event.StartZSet(key, cardinality, expiry) + for i := int64(0); i < cardinality; i++ { + member, err := readZiplistEntry(buf) + if err != nil { + return err + } + scoreBytes, err := readZiplistEntry(buf) + if err != nil { + return err + } + score, err := strconv.ParseFloat(string(scoreBytes), 64) + if err != nil { + return err + } + d.event.Zadd(key, score, member) + } + d.event.EndZSet(key) + return nil +} + +func (d *decode) readZiplistHash(key []byte, expiry int64) error { + ziplist, err := d.readString() + if err != nil { + return err + } + buf := newSliceBuffer(ziplist) + length, err := readZiplistLength(buf) + if err != nil { + return err + } + length /= 2 + d.event.StartHash(key, length, expiry) + for i := int64(0); i < length; i++ { + field, err := readZiplistEntry(buf) + if err != nil { + return err + } + value, err := readZiplistEntry(buf) + if err != nil { + return err + } + d.event.Hset(key, field, value) + } + d.event.EndHash(key) + return nil +} + +func readZiplistLength(buf *sliceBuffer) (int64, error) { + buf.Seek(8, 0) // skip the zlbytes and zltail + lenBytes, err := buf.Slice(2) + if err != nil { + return 0, err + } + return int64(binary.LittleEndian.Uint16(lenBytes)), nil +} + +func readZiplistEntry(buf *sliceBuffer) ([]byte, error) { + prevLen, err := buf.ReadByte() + if err != nil { + return nil, err + } + if prevLen == 254 { + buf.Seek(4, 1) // skip the 4-byte prevlen + } + + header, err := buf.ReadByte() + if err != nil { + return nil, err + } + switch { + case header>>6 == rdbZiplist6bitlenString: + return buf.Slice(int(header & 0x3f)) + case header>>6 == rdbZiplist14bitlenString: + b, err := buf.ReadByte() + if err != nil { + return nil, err + } + return buf.Slice((int(header&0x3f) << 8) | int(b)) + case header>>6 == rdbZiplist32bitlenString: + lenBytes, err := buf.Slice(4) + if err != nil { + return nil, err + } + return buf.Slice(int(binary.BigEndian.Uint32(lenBytes))) + case header == rdbZiplistInt16: + intBytes, err := buf.Slice(2) + if err != nil { + return nil, err + } + return []byte(strconv.FormatInt(int64(int16(binary.LittleEndian.Uint16(intBytes))), 10)), nil + case header == rdbZiplistInt32: + intBytes, err := buf.Slice(4) + if err != nil { + return nil, err + } + return []byte(strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))), 10)), nil + case header == rdbZiplistInt64: + intBytes, err := buf.Slice(8) + if err != nil { + return nil, err + } + return []byte(strconv.FormatInt(int64(binary.LittleEndian.Uint64(intBytes)), 10)), nil + case header == rdbZiplistInt24: + intBytes := make([]byte, 4) + _, err := buf.Read(intBytes[1:]) + if err != nil { + return nil, err + } + return []byte(strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))>>8), 10)), nil + case header 
== rdbZiplistInt8: + b, err := buf.ReadByte() + return []byte(strconv.FormatInt(int64(int8(b)), 10)), err + case header>>4 == rdbZiplistInt4: + return []byte(strconv.FormatInt(int64(header&0x0f)-1, 10)), nil + } + + return nil, fmt.Errorf("rdb: unknown ziplist header byte: %d", header) +} + +func (d *decode) readIntset(key []byte, expiry int64) error { + intset, err := d.readString() + if err != nil { + return err + } + buf := newSliceBuffer(intset) + intSizeBytes, err := buf.Slice(4) + if err != nil { + return err + } + intSize := binary.LittleEndian.Uint32(intSizeBytes) + + if intSize != 2 && intSize != 4 && intSize != 8 { + return fmt.Errorf("rdb: unknown intset encoding: %d", intSize) + } + + lenBytes, err := buf.Slice(4) + if err != nil { + return err + } + cardinality := binary.LittleEndian.Uint32(lenBytes) + + d.event.StartSet(key, int64(cardinality), expiry) + for i := uint32(0); i < cardinality; i++ { + intBytes, err := buf.Slice(int(intSize)) + if err != nil { + return err + } + var intString string + switch intSize { + case 2: + intString = strconv.FormatInt(int64(int16(binary.LittleEndian.Uint16(intBytes))), 10) + case 4: + intString = strconv.FormatInt(int64(int32(binary.LittleEndian.Uint32(intBytes))), 10) + case 8: + intString = strconv.FormatInt(int64(int64(binary.LittleEndian.Uint64(intBytes))), 10) + } + d.event.Sadd(key, []byte(intString)) + } + d.event.EndSet(key) + return nil +} + +func (d *decode) checkHeader() error { + header := make([]byte, 9) + _, err := io.ReadFull(d.r, header) + if err != nil { + return err + } + + if !bytes.Equal(header[:5], []byte("REDIS")) { + return fmt.Errorf("rdb: invalid file format") + } + + version, _ := strconv.ParseInt(string(header[5:]), 10, 64) + if version < 1 || version > 7 { + return fmt.Errorf("rdb: invalid RDB version number %d", version) + } + + return nil +} + +func (d *decode) readString() ([]byte, error) { + length, encoded, err := d.readLength() + if err != nil { + return nil, err + } + if encoded { + switch length { + case rdbEncInt8: + i, err := d.readUint8() + return []byte(strconv.FormatInt(int64(int8(i)), 10)), err + case rdbEncInt16: + i, err := d.readUint16() + return []byte(strconv.FormatInt(int64(int16(i)), 10)), err + case rdbEncInt32: + i, err := d.readUint32() + return []byte(strconv.FormatInt(int64(int32(i)), 10)), err + case rdbEncLZF: + clen, _, err := d.readLength() + if err != nil { + return nil, err + } + ulen, _, err := d.readLength() + if err != nil { + return nil, err + } + compressed := make([]byte, clen) + _, err = io.ReadFull(d.r, compressed) + if err != nil { + return nil, err + } + decompressed := lzfDecompress(compressed, int(ulen)) + if len(decompressed) != int(ulen) { + return nil, fmt.Errorf("decompressed string length %d didn't match expected length %d", len(decompressed), ulen) + } + return decompressed, nil + } + } + + str := make([]byte, length) + _, err = io.ReadFull(d.r, str) + return str, err +} + +func (d *decode) readUint8() (uint8, error) { + b, err := d.r.ReadByte() + return uint8(b), err +} + +func (d *decode) readUint16() (uint16, error) { + _, err := io.ReadFull(d.r, d.intBuf[:2]) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(d.intBuf), nil +} + +func (d *decode) readUint32() (uint32, error) { + _, err := io.ReadFull(d.r, d.intBuf[:4]) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint32(d.intBuf), nil +} + +func (d *decode) readUint64() (uint64, error) { + _, err := io.ReadFull(d.r, d.intBuf) + if err != nil { + return 0, err + } 
+ return binary.LittleEndian.Uint64(d.intBuf), nil +} + +func (d *decode) readUint32Big() (uint32, error) { + _, err := io.ReadFull(d.r, d.intBuf[:4]) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint32(d.intBuf), nil +} + +// Doubles are saved as strings prefixed by an unsigned +// 8 bit integer specifying the length of the representation. +// This 8 bit integer has special values in order to specify the following +// conditions: +// 253: not a number +// 254: + inf +// 255: - inf +func (d *decode) readFloat64() (float64, error) { + length, err := d.readUint8() + if err != nil { + return 0, err + } + switch length { + case 253: + return math.NaN(), nil + case 254: + return math.Inf(0), nil + case 255: + return math.Inf(-1), nil + default: + floatBytes := make([]byte, length) + _, err := io.ReadFull(d.r, floatBytes) + if err != nil { + return 0, err + } + f, err := strconv.ParseFloat(string(floatBytes), 64) + return f, err + } + + panic("not reached") +} + +func (d *decode) readLength() (uint32, bool, error) { + b, err := d.r.ReadByte() + if err != nil { + return 0, false, err + } + // The first two bits of the first byte are used to indicate the length encoding type + switch (b & 0xc0) >> 6 { + case rdb6bitLen: + // When the first two bits are 00, the next 6 bits are the length. + return uint32(b & 0x3f), false, nil + case rdb14bitLen: + // When the first two bits are 01, the next 14 bits are the length. + bb, err := d.r.ReadByte() + if err != nil { + return 0, false, err + } + return (uint32(b&0x3f) << 8) | uint32(bb), false, nil + case rdbEncVal: + // When the first two bits are 11, the next object is encoded. + // The next 6 bits indicate the encoding type. + return uint32(b & 0x3f), true, nil + default: + // When the first two bits are 10, the next 6 bits are discarded. + // The next 4 bytes are the length. 
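+		// e.g. a first byte of 0x80 (prefix bits 10) followed by 0x00 0x00 0x01 0x00 decodes to a length of 256.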
+ length, err := d.readUint32Big() + return length, false, err + } + + panic("not reached") +} + +func verifyDump(d []byte) error { + if len(d) < 10 { + return fmt.Errorf("rdb: invalid dump length") + } + version := binary.LittleEndian.Uint16(d[len(d)-10:]) + if version != uint16(Version) { + return fmt.Errorf("rdb: invalid version %d, expecting %d", version, Version) + } + + if binary.LittleEndian.Uint64(d[len(d)-8:]) != crc64.Digest(d[:len(d)-8]) { + return fmt.Errorf("rdb: invalid CRC checksum") + } + + return nil +} + +func lzfDecompress(in []byte, outlen int) []byte { + out := make([]byte, outlen) + for i, o := 0, 0; i < len(in); { + ctrl := int(in[i]) + i++ + if ctrl < 32 { + for x := 0; x <= ctrl; x++ { + out[o] = in[i] + i++ + o++ + } + } else { + length := ctrl >> 5 + if length == 7 { + length = length + int(in[i]) + i++ + } + ref := o - ((ctrl & 0x1f) << 8) - int(in[i]) - 1 + i++ + for x := 0; x <= length+1; x++ { + out[o] = out[ref] + ref++ + o++ + } + } + } + return out +} diff --git a/vendor/github.com/cupcake/rdb/encoder.go b/vendor/github.com/cupcake/rdb/encoder.go new file mode 100644 index 00000000..7902a7d3 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/encoder.go @@ -0,0 +1,130 @@ +package rdb + +import ( + "encoding/binary" + "fmt" + "hash" + "io" + "math" + "strconv" + + "github.com/cupcake/rdb/crc64" +) + +const Version = 6 + +type Encoder struct { + w io.Writer + crc hash.Hash +} + +func NewEncoder(w io.Writer) *Encoder { + e := &Encoder{crc: crc64.New()} + e.w = io.MultiWriter(w, e.crc) + return e +} + +func (e *Encoder) EncodeHeader() error { + _, err := fmt.Fprintf(e.w, "REDIS%04d", Version) + return err +} + +func (e *Encoder) EncodeFooter() error { + e.w.Write([]byte{rdbFlagEOF}) + _, err := e.w.Write(e.crc.Sum(nil)) + return err +} + +func (e *Encoder) EncodeDumpFooter() error { + binary.Write(e.w, binary.LittleEndian, uint16(Version)) + _, err := e.w.Write(e.crc.Sum(nil)) + return err +} + +func (e *Encoder) EncodeDatabase(n int) error { + e.w.Write([]byte{rdbFlagSelectDB}) + return e.EncodeLength(uint32(n)) +} + +func (e *Encoder) EncodeExpiry(expiry uint64) error { + b := make([]byte, 9) + b[0] = rdbFlagExpiryMS + binary.LittleEndian.PutUint64(b[1:], expiry) + _, err := e.w.Write(b) + return err +} + +func (e *Encoder) EncodeType(v ValueType) error { + _, err := e.w.Write([]byte{byte(v)}) + return err +} + +func (e *Encoder) EncodeString(s []byte) error { + written, err := e.encodeIntString(s) + if written { + return err + } + e.EncodeLength(uint32(len(s))) + _, err = e.w.Write(s) + return err +} + +func (e *Encoder) EncodeLength(l uint32) (err error) { + switch { + case l < 1<<6: + _, err = e.w.Write([]byte{byte(l)}) + case l < 1<<14: + _, err = e.w.Write([]byte{byte(l>>8) | rdb14bitLen<<6, byte(l)}) + default: + b := make([]byte, 5) + b[0] = rdb32bitLen << 6 + binary.BigEndian.PutUint32(b[1:], l) + _, err = e.w.Write(b) + } + return +} + +func (e *Encoder) EncodeFloat(f float64) (err error) { + switch { + case math.IsNaN(f): + _, err = e.w.Write([]byte{253}) + case math.IsInf(f, 1): + _, err = e.w.Write([]byte{254}) + case math.IsInf(f, -1): + _, err = e.w.Write([]byte{255}) + default: + b := []byte(strconv.FormatFloat(f, 'g', 17, 64)) + e.w.Write([]byte{byte(len(b))}) + _, err = e.w.Write(b) + } + return +} + +func (e *Encoder) encodeIntString(b []byte) (written bool, err error) { + s := string(b) + i, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return + } + // if the stringified parsed int isn't exactly the same, we can't encode it as an 
int + if s != strconv.FormatInt(i, 10) { + return + } + switch { + case i >= math.MinInt8 && i <= math.MaxInt8: + _, err = e.w.Write([]byte{rdbEncVal << 6, byte(int8(i))}) + case i >= math.MinInt16 && i <= math.MaxInt16: + b := make([]byte, 3) + b[0] = rdbEncVal<<6 | rdbEncInt16 + binary.LittleEndian.PutUint16(b[1:], uint16(int16(i))) + _, err = e.w.Write(b) + case i >= math.MinInt32 && i <= math.MaxInt32: + b := make([]byte, 5) + b[0] = rdbEncVal<<6 | rdbEncInt32 + binary.LittleEndian.PutUint32(b[1:], uint32(int32(i))) + _, err = e.w.Write(b) + default: + return + } + return true, err +} diff --git a/vendor/github.com/cupcake/rdb/nopdecoder/nop_decoder.go b/vendor/github.com/cupcake/rdb/nopdecoder/nop_decoder.go new file mode 100644 index 00000000..de93a697 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/nopdecoder/nop_decoder.go @@ -0,0 +1,24 @@ +package nopdecoder + +// NopDecoder may be embedded in a real Decoder to avoid implementing methods. +type NopDecoder struct{} + +func (d NopDecoder) StartRDB() {} +func (d NopDecoder) StartDatabase(n int) {} +func (d NopDecoder) Aux(key, value []byte) {} +func (d NopDecoder) ResizeDatabase(dbSize, expiresSize uint32) {} +func (d NopDecoder) EndDatabase(n int) {} +func (d NopDecoder) EndRDB() {} +func (d NopDecoder) Set(key, value []byte, expiry int64) {} +func (d NopDecoder) StartHash(key []byte, length, expiry int64) {} +func (d NopDecoder) Hset(key, field, value []byte) {} +func (d NopDecoder) EndHash(key []byte) {} +func (d NopDecoder) StartSet(key []byte, cardinality, expiry int64) {} +func (d NopDecoder) Sadd(key, member []byte) {} +func (d NopDecoder) EndSet(key []byte) {} +func (d NopDecoder) StartList(key []byte, length, expiry int64) {} +func (d NopDecoder) Rpush(key, value []byte) {} +func (d NopDecoder) EndList(key []byte) {} +func (d NopDecoder) StartZSet(key []byte, cardinality, expiry int64) {} +func (d NopDecoder) Zadd(key []byte, score float64, member []byte) {} +func (d NopDecoder) EndZSet(key []byte) {} diff --git a/vendor/github.com/cupcake/rdb/slice_buffer.go b/vendor/github.com/cupcake/rdb/slice_buffer.go new file mode 100644 index 00000000..b3e12a02 --- /dev/null +++ b/vendor/github.com/cupcake/rdb/slice_buffer.go @@ -0,0 +1,67 @@ +package rdb + +import ( + "errors" + "io" +) + +type sliceBuffer struct { + s []byte + i int +} + +func newSliceBuffer(s []byte) *sliceBuffer { + return &sliceBuffer{s, 0} +} + +func (s *sliceBuffer) Slice(n int) ([]byte, error) { + if s.i+n > len(s.s) { + return nil, io.EOF + } + b := s.s[s.i : s.i+n] + s.i += n + return b, nil +} + +func (s *sliceBuffer) ReadByte() (byte, error) { + if s.i >= len(s.s) { + return 0, io.EOF + } + b := s.s[s.i] + s.i++ + return b, nil +} + +func (s *sliceBuffer) Read(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + if s.i >= len(s.s) { + return 0, io.EOF + } + n := copy(b, s.s[s.i:]) + s.i += n + return n, nil +} + +func (s *sliceBuffer) Seek(offset int64, whence int) (int64, error) { + var abs int64 + switch whence { + case 0: + abs = offset + case 1: + abs = int64(s.i) + offset + case 2: + abs = int64(len(s.s)) + offset + default: + return 0, errors.New("invalid whence") + } + if abs < 0 { + return 0, errors.New("negative position") + } + if abs >= 1<<31 { + return 0, errors.New("position out of range") + } + s.i = int(abs) + return abs, nil +} diff --git a/vendor/github.com/edsrzf/mmap-go/.gitignore b/vendor/github.com/edsrzf/mmap-go/.gitignore new file mode 100644 index 00000000..9aa02c1e --- /dev/null +++ 
b/vendor/github.com/edsrzf/mmap-go/.gitignore @@ -0,0 +1,8 @@ +*.out +*.5 +*.6 +*.8 +*.swp +_obj +_test +testdata diff --git a/vendor/github.com/edsrzf/mmap-go/LICENSE b/vendor/github.com/edsrzf/mmap-go/LICENSE new file mode 100644 index 00000000..8f05f338 --- /dev/null +++ b/vendor/github.com/edsrzf/mmap-go/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2011, Evan Shaw +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/vendor/github.com/edsrzf/mmap-go/README.md b/vendor/github.com/edsrzf/mmap-go/README.md new file mode 100644 index 00000000..4cc2bfe1 --- /dev/null +++ b/vendor/github.com/edsrzf/mmap-go/README.md @@ -0,0 +1,12 @@ +mmap-go +======= + +mmap-go is a portable mmap package for the [Go programming language](http://golang.org). +It has been tested on Linux (386, amd64), OS X, and Windows (386). It should also +work on other Unix-like platforms, but hasn't been tested with them. I'm interested +to hear about the results. + +I haven't been able to add more features without adding significant complexity, +so mmap-go doesn't support mprotect, mincore, and maybe a few other things. +If you're running on a Unix-like platform and need some of these features, +I suggest Gustavo Niemeyer's [gommap](http://labix.org/gommap). diff --git a/vendor/github.com/edsrzf/mmap-go/mmap.go b/vendor/github.com/edsrzf/mmap-go/mmap.go new file mode 100644 index 00000000..14fb2258 --- /dev/null +++ b/vendor/github.com/edsrzf/mmap-go/mmap.go @@ -0,0 +1,116 @@ +// Copyright 2011 Evan Shaw. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file defines the common package interface and contains a little bit of +// factored out logic. + +// Package mmap allows mapping files into memory. It tries to provide a simple, reasonably portable interface, +// but doesn't go out of its way to abstract away every little platform detail. 
+// This specifically means:
+//  * forked processes may or may not inherit mappings
+//  * a file's timestamp may or may not be updated by writes through mappings
+//  * specifying a size larger than the file's actual size can increase the file's size
+//  * if the mapped file is being modified by another process while your program's running, don't expect consistent results between platforms
+package mmap
+
+import (
+	"errors"
+	"os"
+	"reflect"
+	"unsafe"
+)
+
+const (
+	// RDONLY maps the memory read-only.
+	// Attempts to write to the MMap object will result in undefined behavior.
+	RDONLY = 0
+	// RDWR maps the memory as read-write. Writes to the MMap object will update the
+	// underlying file.
+	RDWR = 1 << iota
+	// COPY maps the memory as copy-on-write. Writes to the MMap object will affect
+	// memory, but the underlying file will remain unchanged.
+	COPY
+	// If EXEC is set, the mapped memory is marked as executable.
+	EXEC
+)
+
+const (
+	// If the ANON flag is set, the mapped memory will not be backed by a file.
+	ANON = 1 << iota
+)
+
+// MMap represents a file mapped into memory.
+type MMap []byte
+
+// Map maps an entire file into memory.
+// If ANON is set in flags, f is ignored.
+func Map(f *os.File, prot, flags int) (MMap, error) {
+	return MapRegion(f, -1, prot, flags, 0)
+}
+
+// MapRegion maps part of a file into memory.
+// The offset parameter must be a multiple of the system's page size.
+// If length < 0, the entire file will be mapped.
+// If ANON is set in flags, f is ignored.
+func MapRegion(f *os.File, length int, prot, flags int, offset int64) (MMap, error) {
+	if offset%int64(os.Getpagesize()) != 0 {
+		return nil, errors.New("offset parameter must be a multiple of the system's page size")
+	}
+
+	var fd uintptr
+	if flags&ANON == 0 {
+		fd = uintptr(f.Fd())
+		if length < 0 {
+			fi, err := f.Stat()
+			if err != nil {
+				return nil, err
+			}
+			length = int(fi.Size())
+		}
+	} else {
+		if length <= 0 {
+			return nil, errors.New("anonymous mapping requires non-zero length")
+		}
+		fd = ^uintptr(0)
+	}
+	return mmap(length, uintptr(prot), uintptr(flags), fd, offset)
+}
+
+func (m *MMap) header() *reflect.SliceHeader {
+	return (*reflect.SliceHeader)(unsafe.Pointer(m))
+}
+
+// Lock keeps the mapped region in physical memory, ensuring that it will not be
+// swapped out.
+func (m MMap) Lock() error {
+	dh := m.header()
+	return lock(dh.Data, uintptr(dh.Len))
+}
+
+// Unlock reverses the effect of Lock, allowing the mapped region to potentially
+// be swapped out.
+// If m is already unlocked, an error will result.
+func (m MMap) Unlock() error {
+	dh := m.header()
+	return unlock(dh.Data, uintptr(dh.Len))
+}
+
+// Flush synchronizes the mapping's contents to the file's contents on disk.
+func (m MMap) Flush() error {
+	dh := m.header()
+	return flush(dh.Data, uintptr(dh.Len))
+}
+
+// Unmap deletes the memory mapped region, flushes any remaining changes, and sets
+// m to nil.
+// Trying to read or write any remaining references to m after Unmap is called will
+// result in undefined behavior.
+// Unmap should only be called on the slice value that was originally returned from
+// a call to Map. Calling Unmap on a derived slice may cause errors.
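+//
+// As a hedged sketch of the Map/Unmap pairing described above (the file
+// name "data.bin" is an assumed example, not part of this package):
+//
+//	f, err := os.Open("data.bin")
+//	if err != nil {
+//		// handle error
+//	}
+//	defer f.Close()
+//
+//	m, err := Map(f, RDONLY, 0)
+//	if err != nil {
+//		// handle error
+//	}
+//	defer m.Unmap() // call Unmap on m itself, never on a sub-slice of it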
+func (m *MMap) Unmap() error {
+	dh := m.header()
+	err := unmap(dh.Data, uintptr(dh.Len))
+	*m = nil
+	return err
+}
diff --git a/vendor/github.com/edsrzf/mmap-go/mmap_unix.go b/vendor/github.com/edsrzf/mmap-go/mmap_unix.go
new file mode 100644
index 00000000..4af98420
--- /dev/null
+++ b/vendor/github.com/edsrzf/mmap-go/mmap_unix.go
@@ -0,0 +1,67 @@
+// Copyright 2011 Evan Shaw. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux openbsd solaris netbsd
+
+package mmap
+
+import (
+	"syscall"
+)
+
+func mmap(len int, inprot, inflags, fd uintptr, off int64) ([]byte, error) {
+	flags := syscall.MAP_SHARED
+	prot := syscall.PROT_READ
+	switch {
+	case inprot&COPY != 0:
+		prot |= syscall.PROT_WRITE
+		flags = syscall.MAP_PRIVATE
+	case inprot&RDWR != 0:
+		prot |= syscall.PROT_WRITE
+	}
+	if inprot&EXEC != 0 {
+		prot |= syscall.PROT_EXEC
+	}
+	if inflags&ANON != 0 {
+		flags |= syscall.MAP_ANON
+	}
+
+	b, err := syscall.Mmap(int(fd), off, len, prot, flags)
+	if err != nil {
+		return nil, err
+	}
+	return b, nil
+}
+
+func flush(addr, len uintptr) error {
+	_, _, errno := syscall.Syscall(_SYS_MSYNC, addr, len, _MS_SYNC)
+	if errno != 0 {
+		return syscall.Errno(errno)
+	}
+	return nil
+}
+
+func lock(addr, len uintptr) error {
+	_, _, errno := syscall.Syscall(syscall.SYS_MLOCK, addr, len, 0)
+	if errno != 0 {
+		return syscall.Errno(errno)
+	}
+	return nil
+}
+
+func unlock(addr, len uintptr) error {
+	_, _, errno := syscall.Syscall(syscall.SYS_MUNLOCK, addr, len, 0)
+	if errno != 0 {
+		return syscall.Errno(errno)
+	}
+	return nil
+}
+
+func unmap(addr, len uintptr) error {
+	_, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, addr, len, 0)
+	if errno != 0 {
+		return syscall.Errno(errno)
+	}
+	return nil
+}
diff --git a/vendor/github.com/edsrzf/mmap-go/mmap_windows.go b/vendor/github.com/edsrzf/mmap-go/mmap_windows.go
new file mode 100644
index 00000000..c3d2d02d
--- /dev/null
+++ b/vendor/github.com/edsrzf/mmap-go/mmap_windows.go
@@ -0,0 +1,125 @@
+// Copyright 2011 Evan Shaw. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mmap
+
+import (
+	"errors"
+	"os"
+	"sync"
+	"syscall"
+)
+
+// mmap on Windows is a two-step process.
+// First, we call CreateFileMapping to get a handle.
+// Then, we call MapViewOfFile to get an actual pointer into memory.
+// Because we want to emulate a POSIX-style mmap, we don't want to expose
+// the handle -- only the pointer. We also want to return only a byte slice,
+// not a struct, so it's convenient to manipulate.
+
+// We keep this map so that we can get back the original handle from the memory address.
+var handleLock sync.Mutex
+var handleMap = map[uintptr]syscall.Handle{}
+
+func mmap(len int, prot, flags, hfile uintptr, off int64) ([]byte, error) {
+	flProtect := uint32(syscall.PAGE_READONLY)
+	dwDesiredAccess := uint32(syscall.FILE_MAP_READ)
+	switch {
+	case prot&COPY != 0:
+		flProtect = syscall.PAGE_WRITECOPY
+		dwDesiredAccess = syscall.FILE_MAP_COPY
+	case prot&RDWR != 0:
+		flProtect = syscall.PAGE_READWRITE
+		dwDesiredAccess = syscall.FILE_MAP_WRITE
+	}
+	if prot&EXEC != 0 {
+		flProtect <<= 4
+		dwDesiredAccess |= syscall.FILE_MAP_EXECUTE
+	}
+
+	// The maximum size is the area of the file, starting from 0,
+	// that we wish to allow to be mappable. It is the sum of
+	// the length the user requested, plus the offset where that length
+	// is starting from.
This does not map the data into memory. + maxSizeHigh := uint32((off + int64(len)) >> 32) + maxSizeLow := uint32((off + int64(len)) & 0xFFFFFFFF) + // TODO: Do we need to set some security attributes? It might help portability. + h, errno := syscall.CreateFileMapping(syscall.Handle(hfile), nil, flProtect, maxSizeHigh, maxSizeLow, nil) + if h == 0 { + return nil, os.NewSyscallError("CreateFileMapping", errno) + } + + // Actually map a view of the data into memory. The view's size + // is the length the user requested. + fileOffsetHigh := uint32(off >> 32) + fileOffsetLow := uint32(off & 0xFFFFFFFF) + addr, errno := syscall.MapViewOfFile(h, dwDesiredAccess, fileOffsetHigh, fileOffsetLow, uintptr(len)) + if addr == 0 { + return nil, os.NewSyscallError("MapViewOfFile", errno) + } + handleLock.Lock() + handleMap[addr] = h + handleLock.Unlock() + + m := MMap{} + dh := m.header() + dh.Data = addr + dh.Len = len + dh.Cap = dh.Len + + return m, nil +} + +func flush(addr, len uintptr) error { + errno := syscall.FlushViewOfFile(addr, len) + if errno != nil { + return os.NewSyscallError("FlushViewOfFile", errno) + } + + handleLock.Lock() + defer handleLock.Unlock() + handle, ok := handleMap[addr] + if !ok { + // should be impossible; we would've errored above + return errors.New("unknown base address") + } + + errno = syscall.FlushFileBuffers(handle) + return os.NewSyscallError("FlushFileBuffers", errno) +} + +func lock(addr, len uintptr) error { + errno := syscall.VirtualLock(addr, len) + return os.NewSyscallError("VirtualLock", errno) +} + +func unlock(addr, len uintptr) error { + errno := syscall.VirtualUnlock(addr, len) + return os.NewSyscallError("VirtualUnlock", errno) +} + +func unmap(addr, len uintptr) error { + flush(addr, len) + // Lock the UnmapViewOfFile along with the handleMap deletion. + // As soon as we unmap the view, the OS is free to give the + // same addr to another new map. We don't want another goroutine + // to insert and remove the same addr into handleMap while + // we're trying to remove our old addr/handle pair. + handleLock.Lock() + defer handleLock.Unlock() + err := syscall.UnmapViewOfFile(addr) + if err != nil { + return err + } + + handle, ok := handleMap[addr] + if !ok { + // should be impossible; we would've errored above + return errors.New("unknown base address") + } + delete(handleMap, addr) + + e := syscall.CloseHandle(syscall.Handle(handle)) + return os.NewSyscallError("CloseHandle", e) +} diff --git a/vendor/github.com/edsrzf/mmap-go/msync_netbsd.go b/vendor/github.com/edsrzf/mmap-go/msync_netbsd.go new file mode 100644 index 00000000..a64b003e --- /dev/null +++ b/vendor/github.com/edsrzf/mmap-go/msync_netbsd.go @@ -0,0 +1,8 @@ +// Copyright 2011 Evan Shaw. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mmap + +const _SYS_MSYNC = 277 +const _MS_SYNC = 0x04 diff --git a/vendor/github.com/edsrzf/mmap-go/msync_unix.go b/vendor/github.com/edsrzf/mmap-go/msync_unix.go new file mode 100644 index 00000000..91ee5f40 --- /dev/null +++ b/vendor/github.com/edsrzf/mmap-go/msync_unix.go @@ -0,0 +1,14 @@ +// Copyright 2011 Evan Shaw. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+// +build darwin dragonfly freebsd linux openbsd solaris
+
+package mmap
+
+import (
+	"syscall"
+)
+
+const _SYS_MSYNC = syscall.SYS_MSYNC
+const _MS_SYNC = syscall.MS_SYNC
diff --git a/vendor/github.com/elazarl/go-bindata-assetfs/LICENSE b/vendor/github.com/elazarl/go-bindata-assetfs/LICENSE
new file mode 100644
index 00000000..5782c726
--- /dev/null
+++ b/vendor/github.com/elazarl/go-bindata-assetfs/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2014, Elazar Leibovich
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/vendor/github.com/elazarl/go-bindata-assetfs/README.md b/vendor/github.com/elazarl/go-bindata-assetfs/README.md
new file mode 100644
index 00000000..27ee48f0
--- /dev/null
+++ b/vendor/github.com/elazarl/go-bindata-assetfs/README.md
@@ -0,0 +1,46 @@
+# go-bindata-assetfs
+
+Serve embedded files from [jteeuwen/go-bindata](https://github.com/jteeuwen/go-bindata) with `net/http`.
+
+[GoDoc](http://godoc.org/github.com/elazarl/go-bindata-assetfs)
+
+### Installation
+
+Install with
+
+    $ go get github.com/jteeuwen/go-bindata/...
+    $ go get github.com/elazarl/go-bindata-assetfs/...
+
+### Creating embedded data
+
+Usage is identical to [jteeuwen/go-bindata](https://github.com/jteeuwen/go-bindata) usage; instead of running `go-bindata`, run `go-bindata-assetfs`.
+
+The tool will create a `bindata_assetfs.go` file, which contains the embedded data.
+
+A typical use case is:
+
+    $ go-bindata-assetfs data/...
+
+### Using assetFS in your code
+
+The generated file provides an `assetFS()` function that returns an `http.FileSystem`
+wrapping the embedded files. What you usually want to do is:
+
+    http.Handle("/", http.FileServer(assetFS()))
+
+This would run an HTTP server serving the embedded files.
+
+## Without running the binary tool
+
+You can always just run the `go-bindata` tool, and then use
+
+    import "github.com/elazarl/go-bindata-assetfs"
+    ...
+    http.Handle("/",
+        http.FileServer(
+            &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: AssetInfo, Prefix: "data"}))
+
+to serve files embedded from the `data` directory.
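+
+For example, a hedged end-to-end sketch (`Asset`, `AssetDir` and `AssetInfo` are the
+functions go-bindata generates into your package; they are assumed here, not defined):
+
+    import (
+        "net/http"
+
+        assetfs "github.com/elazarl/go-bindata-assetfs"
+    )
+
+    func main() {
+        fs := &assetfs.AssetFS{
+            Asset:     Asset,     // generated by go-bindata (assumed)
+            AssetDir:  AssetDir,  // generated by go-bindata (assumed)
+            AssetInfo: AssetInfo, // generated by go-bindata (assumed)
+            Prefix:    "data",    // look up request paths under data/
+        }
+        http.Handle("/", http.FileServer(fs))
+        http.ListenAndServe(":8080", nil)
+    }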
diff --git a/vendor/github.com/elazarl/go-bindata-assetfs/assetfs.go b/vendor/github.com/elazarl/go-bindata-assetfs/assetfs.go
new file mode 100644
index 00000000..04f6d7a3
--- /dev/null
+++ b/vendor/github.com/elazarl/go-bindata-assetfs/assetfs.go
@@ -0,0 +1,167 @@
+package assetfs
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"os"
+	"path"
+	"path/filepath"
+	"strings"
+	"time"
+)
+
+var (
+	defaultFileTimestamp = time.Now()
+)
+
+// FakeFile implements the os.FileInfo interface for a given path and size.
+type FakeFile struct {
+	// Path is the path of this file
+	Path string
+	// Dir marks whether the path is a directory
+	Dir bool
+	// Len is the length of the fake file, zero if it is a directory
+	Len int64
+	// Timestamp is the ModTime of this file
+	Timestamp time.Time
+}
+
+func (f *FakeFile) Name() string {
+	_, name := filepath.Split(f.Path)
+	return name
+}
+
+func (f *FakeFile) Mode() os.FileMode {
+	mode := os.FileMode(0644)
+	if f.Dir {
+		return mode | os.ModeDir
+	}
+	return mode
+}
+
+func (f *FakeFile) ModTime() time.Time {
+	return f.Timestamp
+}
+
+func (f *FakeFile) Size() int64 {
+	return f.Len
+}
+
+func (f *FakeFile) IsDir() bool {
+	return f.Mode().IsDir()
+}
+
+func (f *FakeFile) Sys() interface{} {
+	return nil
+}
+
+// AssetFile implements the http.File interface for a non-directory file with content.
+type AssetFile struct {
+	*bytes.Reader
+	io.Closer
+	FakeFile
+}
+
+func NewAssetFile(name string, content []byte, timestamp time.Time) *AssetFile {
+	if timestamp.IsZero() {
+		timestamp = defaultFileTimestamp
+	}
+	return &AssetFile{
+		bytes.NewReader(content),
+		ioutil.NopCloser(nil),
+		FakeFile{name, false, int64(len(content)), timestamp}}
+}
+
+func (f *AssetFile) Readdir(count int) ([]os.FileInfo, error) {
+	return nil, errors.New("not a directory")
+}
+
+func (f *AssetFile) Size() int64 {
+	return f.FakeFile.Size()
+}
+
+func (f *AssetFile) Stat() (os.FileInfo, error) {
+	return f, nil
+}
+
+// AssetDirectory implements the http.File interface for a directory.
+type AssetDirectory struct {
+	AssetFile
+	ChildrenRead int
+	Children     []os.FileInfo
+}
+
+func NewAssetDirectory(name string, children []string, fs *AssetFS) *AssetDirectory {
+	fileinfos := make([]os.FileInfo, 0, len(children))
+	for _, child := range children {
+		_, err := fs.AssetDir(filepath.Join(name, child))
+		fileinfos = append(fileinfos, &FakeFile{child, err == nil, 0, time.Time{}})
+	}
+	return &AssetDirectory{
+		AssetFile{
+			bytes.NewReader(nil),
+			ioutil.NopCloser(nil),
+			FakeFile{name, true, 0, time.Time{}},
+		},
+		0,
+		fileinfos}
+}
+
+func (f *AssetDirectory) Readdir(count int) ([]os.FileInfo, error) {
+	if count <= 0 {
+		return f.Children, nil
+	}
+	if f.ChildrenRead+count > len(f.Children) {
+		count = len(f.Children) - f.ChildrenRead
+	}
+	rv := f.Children[f.ChildrenRead : f.ChildrenRead+count]
+	f.ChildrenRead += count
+	return rv, nil
+}
+
+func (f *AssetDirectory) Stat() (os.FileInfo, error) {
+	return f, nil
+}
+
+// AssetFS implements http.FileSystem, allowing
+// embedded files to be served by the net/http package.
+type AssetFS struct {
+	// Asset should return the content of the file at path, if it exists
+	Asset func(path string) ([]byte, error)
+	// AssetDir should return the list of files in the path
+	AssetDir func(path string) ([]string, error)
+	// AssetInfo should return the info of the file at path, if it exists
+	AssetInfo func(path string) (os.FileInfo, error)
+	// Prefix is prepended to the paths of incoming http requests
+	Prefix string
+}
+
+func (fs *AssetFS) Open(name string) (http.File, error) {
+	name = path.Join(fs.Prefix, name)
+	if len(name) > 0 && name[0] == '/' {
+		name = name[1:]
+	}
+	if b, err := fs.Asset(name); err == nil {
+		timestamp := defaultFileTimestamp
+		if fs.AssetInfo != nil {
+			if info, err := fs.AssetInfo(name); err == nil {
+				timestamp = info.ModTime()
+			}
+		}
+		return NewAssetFile(name, b, timestamp), nil
+	}
+	if children, err := fs.AssetDir(name); err == nil {
+		return NewAssetDirectory(name, children, fs), nil
+	} else {
+		// If the underlying error says "not found", translate it to
+		// os.ErrNotExist so the server responds with a 404 rather than
+		// a 500 for missing files.
+		if strings.Contains(err.Error(), "not found") {
+			return nil, os.ErrNotExist
+		}
+		return nil, err
+	}
+}
diff --git a/vendor/github.com/elazarl/go-bindata-assetfs/doc.go b/vendor/github.com/elazarl/go-bindata-assetfs/doc.go
new file mode 100644
index 00000000..a664249f
--- /dev/null
+++ b/vendor/github.com/elazarl/go-bindata-assetfs/doc.go
@@ -0,0 +1,13 @@
+// Package assetfs allows packages to serve static content embedded
+// with the go-bindata tool using the standard net/http package.
+//
+// See https://github.com/jteeuwen/go-bindata for more information
+// about embedding binary data with go-bindata.
+//
+// Usage example, after running
+//    $ go-bindata data/...
+// use:
+//    http.Handle("/",
+//       http.FileServer(
+//          &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "data"}))
+package assetfs
diff --git a/vendor/github.com/go-redis/redis/.gitignore b/vendor/github.com/go-redis/redis/.gitignore
new file mode 100644
index 00000000..ebfe903b
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/.gitignore
@@ -0,0 +1,2 @@
+*.rdb
+testdata/*/
diff --git a/vendor/github.com/go-redis/redis/.travis.yml b/vendor/github.com/go-redis/redis/.travis.yml
new file mode 100644
index 00000000..632feca0
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/.travis.yml
@@ -0,0 +1,21 @@
+sudo: false
+language: go
+
+services:
+  - redis-server
+
+go:
+  - 1.7.x
+  - 1.8.x
+  - 1.9.x
+  - 1.10.x
+  - 1.11.x
+  - tip
+
+matrix:
+  allow_failures:
+    - go: tip
+
+install:
+  - go get github.com/onsi/ginkgo
+  - go get github.com/onsi/gomega
diff --git a/vendor/github.com/go-redis/redis/CHANGELOG.md b/vendor/github.com/go-redis/redis/CHANGELOG.md
new file mode 100644
index 00000000..19645661
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/CHANGELOG.md
@@ -0,0 +1,25 @@
+# Changelog
+
+## Unreleased
+
+- Cluster and Ring pipelines process commands for each node in its own goroutine.
+
+## 6.14
+
+- Added Options.MinIdleConns.
+- Added Options.MaxConnAge.
+- PoolStats.FreeConns is renamed to PoolStats.IdleConns.
+- Add Client.Do to simplify creating custom commands.
+- Add Cmd.String, Cmd.Int, Cmd.Int64, Cmd.Uint64, Cmd.Float64, and Cmd.Bool helpers.
+- Lower memory usage.
+
+## v6.13
+
+- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set `HashReplicas = 1000` for better key distribution between shards.
+- Cluster client was optimized to use much less memory when reloading cluster state.
+- PubSub.ReceiveMessage is re-worked to not use ReceiveTimeout, so it does not lose data when a timeout occurs. In most cases it is recommended to use PubSub.Channel instead.
+- Dialer.KeepAlive is set to 5 minutes by default.
+
+## v6.12
+
+- ClusterClient got a new option called `ClusterSlots`, which allows building a cluster of normal Redis servers that don't have cluster mode enabled. See https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup
diff --git a/vendor/github.com/go-redis/redis/LICENSE b/vendor/github.com/go-redis/redis/LICENSE
new file mode 100644
index 00000000..298bed9b
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/LICENSE
@@ -0,0 +1,25 @@
+Copyright (c) 2013 The github.com/go-redis/redis Authors.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+   * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+   * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/vendor/github.com/go-redis/redis/Makefile b/vendor/github.com/go-redis/redis/Makefile
new file mode 100644
index 00000000..1fbdac91
--- /dev/null
+++ b/vendor/github.com/go-redis/redis/Makefile
@@ -0,0 +1,20 @@
+all: testdeps
+	go test ./...
+	go test ./... -short -race
+	env GOOS=linux GOARCH=386 go test ./...
+	go vet
+
+testdeps: testdata/redis/src/redis-server
+
+bench: testdeps
+	go test ./... -test.run=NONE -test.bench=. -test.benchmem
+
+.PHONY: all test testdeps bench
+
+testdata/redis:
+	mkdir -p $@
+	wget -qO- https://github.com/antirez/redis/archive/unstable.tar.gz | tar xvz --strip-components=1 -C $@
+
+testdata/redis/src/redis-server: testdata/redis
+	sed -i.bak 's/libjemalloc.a/libjemalloc.a -lrt/g' $
+}
+
+func ExampleClient() {
+	err := client.Set("key", "value", 0).Err()
+	if err != nil {
+		panic(err)
+	}
+
+	val, err := client.Get("key").Result()
+	if err != nil {
+		panic(err)
+	}
+	fmt.Println("key", val)
+
+	val2, err := client.Get("key2").Result()
+	if err == redis.Nil {
+		fmt.Println("key2 does not exist")
+	} else if err != nil {
+		panic(err)
+	} else {
+		fmt.Println("key2", val2)
+	}
+	// Output: key value
+	// key2 does not exist
+}
+```
+
+## Howto
+
+Please go through [examples](https://godoc.org/github.com/go-redis/redis#pkg-examples) to get an idea of how to use this package.
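+
+A hedged pipelining sketch (`client` is assumed to be a `*redis.Client` created
+elsewhere, e.g. with `redis.NewClient`):
+
+```go
+cmds, err := client.Pipelined(func(pipe redis.Pipeliner) error {
+	pipe.Set("counter", 0, 0)
+	pipe.Incr("counter")
+	return nil
+})
+if err != nil {
+	panic(err)
+}
+fmt.Println(cmds) // each Cmder carries its own result and error
+```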
+ +## Look and feel + +Some corner cases: + +```go +// SET key value EX 10 NX +set, err := client.SetNX("key", "value", 10*time.Second).Result() + +// SORT list LIMIT 0 2 ASC +vals, err := client.Sort("list", redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() + +// ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 +vals, err := client.ZRangeByScoreWithScores("zset", redis.ZRangeBy{ + Min: "-inf", + Max: "+inf", + Offset: 0, + Count: 2, +}).Result() + +// ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM +vals, err := client.ZInterStore("out", redis.ZStore{Weights: []int64{2, 3}}, "zset1", "zset2").Result() + +// EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" +vals, err := client.Eval("return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result() +``` + +## Benchmark + +go-redis vs redigo: + +``` +BenchmarkSetGoRedis10Conns64Bytes-4 200000 7621 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis100Conns64Bytes-4 200000 7554 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis10Conns1KB-4 200000 7697 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis100Conns1KB-4 200000 7688 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis10Conns10KB-4 200000 9214 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis100Conns10KB-4 200000 9181 ns/op 210 B/op 6 allocs/op +BenchmarkSetGoRedis10Conns1MB-4 2000 583242 ns/op 2337 B/op 6 allocs/op +BenchmarkSetGoRedis100Conns1MB-4 2000 583089 ns/op 2338 B/op 6 allocs/op +BenchmarkSetRedigo10Conns64Bytes-4 200000 7576 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo100Conns64Bytes-4 200000 7782 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo10Conns1KB-4 200000 7958 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo100Conns1KB-4 200000 7725 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo10Conns10KB-4 100000 18442 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo100Conns10KB-4 100000 18818 ns/op 208 B/op 7 allocs/op +BenchmarkSetRedigo10Conns1MB-4 2000 668829 ns/op 226 B/op 7 allocs/op +BenchmarkSetRedigo100Conns1MB-4 2000 679542 ns/op 226 B/op 7 allocs/op +``` + +Redis Cluster: + +``` +BenchmarkRedisPing-4 200000 6983 ns/op 116 B/op 4 allocs/op +BenchmarkRedisClusterPing-4 100000 11535 ns/op 117 B/op 4 allocs/op +``` + +## See also + +- [Golang PostgreSQL ORM](https://github.com/go-pg/pg) +- [Golang msgpack](https://github.com/vmihailenco/msgpack) +- [Golang message task queue](https://github.com/go-msgqueue/msgqueue) diff --git a/vendor/github.com/go-redis/redis/cluster.go b/vendor/github.com/go-redis/redis/cluster.go new file mode 100644 index 00000000..55bc5bae --- /dev/null +++ b/vendor/github.com/go-redis/redis/cluster.go @@ -0,0 +1,1649 @@ +package redis + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "math" + "math/rand" + "net" + "runtime" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/go-redis/redis/internal" + "github.com/go-redis/redis/internal/hashtag" + "github.com/go-redis/redis/internal/pool" + "github.com/go-redis/redis/internal/proto" + "github.com/go-redis/redis/internal/singleflight" +) + +var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes") + +// ClusterOptions are used to configure a cluster client and should be +// passed to NewClusterClient. +type ClusterOptions struct { + // A seed list of host:port addresses of cluster nodes. + Addrs []string + + // The maximum number of retries before giving up. Command is retried + // on network errors and MOVED/ASK redirects. + // Default is 8 retries. + MaxRedirects int + + // Enables read-only commands on slave nodes. 
+	ReadOnly bool
+	// Allows routing read-only commands to the closest master or slave node.
+	// It automatically enables ReadOnly.
+	RouteByLatency bool
+	// Allows routing read-only commands to a random master or slave node.
+	// It automatically enables ReadOnly.
+	RouteRandomly bool
+
+	// Optional function that returns cluster slots information.
+	// It is useful for manually creating a cluster of standalone Redis servers
+	// and load-balancing read/write operations between masters and slaves.
+	// It can use a service like ZooKeeper to maintain configuration information
+	// and Cluster.ReloadState to manually trigger state reloading.
+	ClusterSlots func() ([]ClusterSlot, error)
+
+	// The following options are copied from the Options struct.
+
+	OnConnect func(*Conn) error
+
+	Password string
+
+	MaxRetries      int
+	MinRetryBackoff time.Duration
+	MaxRetryBackoff time.Duration
+
+	DialTimeout  time.Duration
+	ReadTimeout  time.Duration
+	WriteTimeout time.Duration
+
+	// PoolSize applies per cluster node and not for the whole cluster.
+	PoolSize           int
+	MinIdleConns       int
+	MaxConnAge         time.Duration
+	PoolTimeout        time.Duration
+	IdleTimeout        time.Duration
+	IdleCheckFrequency time.Duration
+
+	TLSConfig *tls.Config
+}
+
+func (opt *ClusterOptions) init() {
+	if opt.MaxRedirects == -1 {
+		opt.MaxRedirects = 0
+	} else if opt.MaxRedirects == 0 {
+		opt.MaxRedirects = 8
+	}
+
+	if (opt.RouteByLatency || opt.RouteRandomly) && opt.ClusterSlots == nil {
+		opt.ReadOnly = true
+	}
+
+	if opt.PoolSize == 0 {
+		opt.PoolSize = 5 * runtime.NumCPU()
+	}
+
+	switch opt.ReadTimeout {
+	case -1:
+		opt.ReadTimeout = 0
+	case 0:
+		opt.ReadTimeout = 3 * time.Second
+	}
+	switch opt.WriteTimeout {
+	case -1:
+		opt.WriteTimeout = 0
+	case 0:
+		opt.WriteTimeout = opt.ReadTimeout
+	}
+
+	switch opt.MinRetryBackoff {
+	case -1:
+		opt.MinRetryBackoff = 0
+	case 0:
+		opt.MinRetryBackoff = 8 * time.Millisecond
+	}
+	switch opt.MaxRetryBackoff {
+	case -1:
+		opt.MaxRetryBackoff = 0
+	case 0:
+		opt.MaxRetryBackoff = 512 * time.Millisecond
+	}
+}
+
+func (opt *ClusterOptions) clientOptions() *Options {
+	const disableIdleCheck = -1
+
+	return &Options{
+		OnConnect: opt.OnConnect,
+
+		MaxRetries:      opt.MaxRetries,
+		MinRetryBackoff: opt.MinRetryBackoff,
+		MaxRetryBackoff: opt.MaxRetryBackoff,
+		Password:        opt.Password,
+		readOnly:        opt.ReadOnly,
+
+		DialTimeout:  opt.DialTimeout,
+		ReadTimeout:  opt.ReadTimeout,
+		WriteTimeout: opt.WriteTimeout,
+
+		PoolSize:           opt.PoolSize,
+		MinIdleConns:       opt.MinIdleConns,
+		MaxConnAge:         opt.MaxConnAge,
+		PoolTimeout:        opt.PoolTimeout,
+		IdleTimeout:        opt.IdleTimeout,
+		IdleCheckFrequency: disableIdleCheck,
+
+		TLSConfig: opt.TLSConfig,
+	}
+}
+
+//------------------------------------------------------------------------------
+
+type clusterNode struct {
+	Client *Client
+
+	latency    uint32 // atomic
+	generation uint32 // atomic
+	loading    uint32 // atomic
+}
+
+func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode {
+	opt := clOpt.clientOptions()
+	opt.Addr = addr
+	node := clusterNode{
+		Client: NewClient(opt),
+	}
+
+	node.latency = math.MaxUint32
+	if clOpt.RouteByLatency {
+		go node.updateLatency()
+	}
+
+	return &node
+}
+
+func (n *clusterNode) String() string {
+	return n.Client.String()
+}
+
+func (n *clusterNode) Close() error {
+	return n.Client.Close()
+}
+
+func (n *clusterNode) updateLatency() {
+	const probes = 10
+
+	var latency uint32
+	for i := 0; i < probes; i++ {
+		start := time.Now()
+		n.Client.Ping()
+		probe := uint32(time.Since(start) / time.Microsecond)
+		latency = (latency +
probe) / 2 + } + atomic.StoreUint32(&n.latency, latency) +} + +func (n *clusterNode) Latency() time.Duration { + latency := atomic.LoadUint32(&n.latency) + return time.Duration(latency) * time.Microsecond +} + +func (n *clusterNode) MarkAsLoading() { + atomic.StoreUint32(&n.loading, uint32(time.Now().Unix())) +} + +func (n *clusterNode) Loading() bool { + const minute = int64(time.Minute / time.Second) + + loading := atomic.LoadUint32(&n.loading) + if loading == 0 { + return false + } + if time.Now().Unix()-int64(loading) < minute { + return true + } + atomic.StoreUint32(&n.loading, 0) + return false +} + +func (n *clusterNode) Generation() uint32 { + return atomic.LoadUint32(&n.generation) +} + +func (n *clusterNode) SetGeneration(gen uint32) { + for { + v := atomic.LoadUint32(&n.generation) + if gen < v || atomic.CompareAndSwapUint32(&n.generation, v, gen) { + break + } + } +} + +//------------------------------------------------------------------------------ + +type clusterNodes struct { + opt *ClusterOptions + + mu sync.RWMutex + allAddrs []string + allNodes map[string]*clusterNode + clusterAddrs []string + closed bool + + nodeCreateGroup singleflight.Group + + _generation uint32 // atomic +} + +func newClusterNodes(opt *ClusterOptions) *clusterNodes { + return &clusterNodes{ + opt: opt, + + allAddrs: opt.Addrs, + allNodes: make(map[string]*clusterNode), + } +} + +func (c *clusterNodes) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil + } + c.closed = true + + var firstErr error + for _, node := range c.allNodes { + if err := node.Client.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + c.allNodes = nil + c.clusterAddrs = nil + + return firstErr +} + +func (c *clusterNodes) Addrs() ([]string, error) { + var addrs []string + c.mu.RLock() + closed := c.closed + if !closed { + if len(c.clusterAddrs) > 0 { + addrs = c.clusterAddrs + } else { + addrs = c.allAddrs + } + } + c.mu.RUnlock() + + if closed { + return nil, pool.ErrClosed + } + if len(addrs) == 0 { + return nil, errClusterNoNodes + } + return addrs, nil +} + +func (c *clusterNodes) NextGeneration() uint32 { + return atomic.AddUint32(&c._generation, 1) +} + +// GC removes unused nodes. 
+func (c *clusterNodes) GC(generation uint32) { + var collected []*clusterNode + c.mu.Lock() + for addr, node := range c.allNodes { + if node.Generation() >= generation { + continue + } + + c.clusterAddrs = remove(c.clusterAddrs, addr) + delete(c.allNodes, addr) + collected = append(collected, node) + } + c.mu.Unlock() + + for _, node := range collected { + _ = node.Client.Close() + } +} + +func (c *clusterNodes) Get(addr string) (*clusterNode, error) { + var node *clusterNode + var err error + c.mu.RLock() + if c.closed { + err = pool.ErrClosed + } else { + node = c.allNodes[addr] + } + c.mu.RUnlock() + return node, err +} + +func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { + node, err := c.Get(addr) + if err != nil { + return nil, err + } + if node != nil { + return node, nil + } + + v, err := c.nodeCreateGroup.Do(addr, func() (interface{}, error) { + node := newClusterNode(c.opt, addr) + return node, nil + }) + + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil, pool.ErrClosed + } + + node, ok := c.allNodes[addr] + if ok { + _ = v.(*clusterNode).Close() + return node, err + } + node = v.(*clusterNode) + + c.allAddrs = appendIfNotExists(c.allAddrs, addr) + if err == nil { + c.clusterAddrs = append(c.clusterAddrs, addr) + } + c.allNodes[addr] = node + + return node, err +} + +func (c *clusterNodes) All() ([]*clusterNode, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.closed { + return nil, pool.ErrClosed + } + + cp := make([]*clusterNode, 0, len(c.allNodes)) + for _, node := range c.allNodes { + cp = append(cp, node) + } + return cp, nil +} + +func (c *clusterNodes) Random() (*clusterNode, error) { + addrs, err := c.Addrs() + if err != nil { + return nil, err + } + + n := rand.Intn(len(addrs)) + return c.GetOrCreate(addrs[n]) +} + +//------------------------------------------------------------------------------ + +type clusterSlot struct { + start, end int + nodes []*clusterNode +} + +type clusterSlotSlice []*clusterSlot + +func (p clusterSlotSlice) Len() int { + return len(p) +} + +func (p clusterSlotSlice) Less(i, j int) bool { + return p[i].start < p[j].start +} + +func (p clusterSlotSlice) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +type clusterState struct { + nodes *clusterNodes + Masters []*clusterNode + Slaves []*clusterNode + + slots []*clusterSlot + + generation uint32 + createdAt time.Time +} + +func newClusterState( + nodes *clusterNodes, slots []ClusterSlot, origin string, +) (*clusterState, error) { + c := clusterState{ + nodes: nodes, + + slots: make([]*clusterSlot, 0, len(slots)), + + generation: nodes.NextGeneration(), + createdAt: time.Now(), + } + + originHost, _, _ := net.SplitHostPort(origin) + isLoopbackOrigin := isLoopback(originHost) + + for _, slot := range slots { + var nodes []*clusterNode + for i, slotNode := range slot.Nodes { + addr := slotNode.Addr + if !isLoopbackOrigin { + addr = replaceLoopbackHost(addr, originHost) + } + + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + return nil, err + } + + node.SetGeneration(c.generation) + nodes = append(nodes, node) + + if i == 0 { + c.Masters = appendUniqueNode(c.Masters, node) + } else { + c.Slaves = appendUniqueNode(c.Slaves, node) + } + } + + c.slots = append(c.slots, &clusterSlot{ + start: slot.Start, + end: slot.End, + nodes: nodes, + }) + } + + sort.Sort(clusterSlotSlice(c.slots)) + + time.AfterFunc(time.Minute, func() { + nodes.GC(c.generation) + }) + + return &c, nil +} + +func replaceLoopbackHost(nodeAddr, originHost string) string { + 
nodeHost, nodePort, err := net.SplitHostPort(nodeAddr)
+	if err != nil {
+		return nodeAddr
+	}
+
+	nodeIP := net.ParseIP(nodeHost)
+	if nodeIP == nil {
+		return nodeAddr
+	}
+
+	if !nodeIP.IsLoopback() {
+		return nodeAddr
+	}
+
+	// Use the origin host, which is not a loopback address, together with the node's port.
+	return net.JoinHostPort(originHost, nodePort)
+}
+
+func isLoopback(host string) bool {
+	ip := net.ParseIP(host)
+	if ip == nil {
+		return true
+	}
+	return ip.IsLoopback()
+}
+
+func (c *clusterState) slotMasterNode(slot int) (*clusterNode, error) {
+	nodes := c.slotNodes(slot)
+	if len(nodes) > 0 {
+		return nodes[0], nil
+	}
+	return c.nodes.Random()
+}
+
+func (c *clusterState) slotSlaveNode(slot int) (*clusterNode, error) {
+	nodes := c.slotNodes(slot)
+	switch len(nodes) {
+	case 0:
+		return c.nodes.Random()
+	case 1:
+		return nodes[0], nil
+	case 2:
+		if slave := nodes[1]; !slave.Loading() {
+			return slave, nil
+		}
+		return nodes[0], nil
+	default:
+		var slave *clusterNode
+		for i := 0; i < 10; i++ {
+			n := rand.Intn(len(nodes)-1) + 1
+			slave = nodes[n]
+			if !slave.Loading() {
+				break
+			}
+		}
+		return slave, nil
+	}
+}
+
+func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) {
+	const threshold = time.Millisecond
+
+	nodes := c.slotNodes(slot)
+	if len(nodes) == 0 {
+		return c.nodes.Random()
+	}
+
+	var node *clusterNode
+	for _, n := range nodes {
+		if n.Loading() {
+			continue
+		}
+		if node == nil || node.Latency()-n.Latency() > threshold {
+			node = n
+		}
+	}
+	return node, nil
+}
+
+func (c *clusterState) slotRandomNode(slot int) *clusterNode {
+	nodes := c.slotNodes(slot)
+	n := rand.Intn(len(nodes))
+	return nodes[n]
+}
+
+func (c *clusterState) slotNodes(slot int) []*clusterNode {
+	i := sort.Search(len(c.slots), func(i int) bool {
+		return c.slots[i].end >= slot
+	})
+	if i >= len(c.slots) {
+		return nil
+	}
+	x := c.slots[i]
+	if slot >= x.start && slot <= x.end {
+		return x.nodes
+	}
+	return nil
+}
+
+func (c *clusterState) IsConsistent() bool {
+	if c.nodes.opt.ClusterSlots != nil {
+		return true
+	}
+	return len(c.Masters) <= len(c.Slaves)
+}
+
+//------------------------------------------------------------------------------
+
+type clusterStateHolder struct {
+	load func() (*clusterState, error)
+
+	state atomic.Value
+
+	firstErrMu sync.RWMutex
+	firstErr   error
+
+	reloading uint32 // atomic
+}
+
+func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder {
+	return &clusterStateHolder{
+		load: fn,
+	}
+}
+
+func (c *clusterStateHolder) Reload() (*clusterState, error) {
+	state, err := c.reload()
+	if err != nil {
+		return nil, err
+	}
+	if !state.IsConsistent() {
+		time.AfterFunc(time.Second, c.LazyReload)
+	}
+	return state, nil
+}
+
+func (c *clusterStateHolder) reload() (*clusterState, error) {
+	state, err := c.load()
+	if err != nil {
+		c.firstErrMu.Lock()
+		if c.firstErr == nil {
+			c.firstErr = err
+		}
+		c.firstErrMu.Unlock()
+		return nil, err
+	}
+	c.state.Store(state)
+	return state, nil
+}
+
+func (c *clusterStateHolder) LazyReload() {
+	if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
+		return
+	}
+	go func() {
+		defer atomic.StoreUint32(&c.reloading, 0)
+
+		for {
+			state, err := c.reload()
+			if err != nil {
+				return
+			}
+			time.Sleep(100 * time.Millisecond)
+			if state.IsConsistent() {
+				return
+			}
+		}
+	}()
+}
+
+func (c *clusterStateHolder) Get() (*clusterState, error) {
+	v := c.state.Load()
+	if v != nil {
+		state := v.(*clusterState)
+		if time.Since(state.createdAt) > time.Minute {
+			c.LazyReload()
+		}
+		return state, nil
+	}
+
+	
c.firstErrMu.RLock() + err := c.firstErr + c.firstErrMu.RUnlock() + if err != nil { + return nil, err + } + + return nil, errors.New("redis: cluster has no state") +} + +func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) { + state, err := c.Reload() + if err == nil { + return state, nil + } + return c.Get() +} + +//------------------------------------------------------------------------------ + +// ClusterClient is a Redis Cluster client representing a pool of zero +// or more underlying connections. It's safe for concurrent use by +// multiple goroutines. +type ClusterClient struct { + cmdable + + ctx context.Context + + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache + + process func(Cmder) error + processPipeline func([]Cmder) error + processTxPipeline func([]Cmder) error +} + +// NewClusterClient returns a Redis Cluster client as described in +// http://redis.io/topics/cluster-spec. +func NewClusterClient(opt *ClusterOptions) *ClusterClient { + opt.init() + + c := &ClusterClient{ + opt: opt, + nodes: newClusterNodes(opt), + } + c.state = newClusterStateHolder(c.loadState) + c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) + + c.process = c.defaultProcess + c.processPipeline = c.defaultProcessPipeline + c.processTxPipeline = c.defaultProcessTxPipeline + + c.init() + + _, _ = c.state.Reload() + _, _ = c.cmdsInfoCache.Get() + + if opt.IdleCheckFrequency > 0 { + go c.reaper(opt.IdleCheckFrequency) + } + + return c +} + +// ReloadState reloads cluster state. It calls ClusterSlots func +// to get cluster slots information. +func (c *ClusterClient) ReloadState() error { + _, err := c.state.Reload() + return err +} + +func (c *ClusterClient) init() { + c.cmdable.setProcessor(c.Process) +} + +func (c *ClusterClient) Context() context.Context { + if c.ctx != nil { + return c.ctx + } + return context.Background() +} + +func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { + if ctx == nil { + panic("nil context") + } + c2 := c.copy() + c2.ctx = ctx + return c2 +} + +func (c *ClusterClient) copy() *ClusterClient { + cp := *c + cp.init() + return &cp +} + +// Options returns read-only Options that were used to create the client. 
+func (c *ClusterClient) Options() *ClusterOptions { + return c.opt +} + +func (c *ClusterClient) retryBackoff(attempt int) time.Duration { + return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) +} + +func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) { + addrs, err := c.nodes.Addrs() + if err != nil { + return nil, err + } + + var firstErr error + for _, addr := range addrs { + node, err := c.nodes.Get(addr) + if err != nil { + return nil, err + } + if node == nil { + continue + } + + info, err := node.Client.Command().Result() + if err == nil { + return info, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, firstErr +} + +func (c *ClusterClient) cmdInfo(name string) *CommandInfo { + cmdsInfo, err := c.cmdsInfoCache.Get() + if err != nil { + return nil + } + + info := cmdsInfo[name] + if info == nil { + internal.Logf("info for cmd=%s not found", name) + } + return info +} + +func cmdSlot(cmd Cmder, pos int) int { + if pos == 0 { + return hashtag.RandomSlot() + } + firstKey := cmd.stringArg(pos) + return hashtag.Slot(firstKey) +} + +func (c *ClusterClient) cmdSlot(cmd Cmder) int { + cmdInfo := c.cmdInfo(cmd.Name()) + return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo)) +} + +func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) { + state, err := c.state.Get() + if err != nil { + return 0, nil, err + } + + cmdInfo := c.cmdInfo(cmd.Name()) + slot := cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo)) + + if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly { + if c.opt.RouteByLatency { + node, err := state.slotClosestNode(slot) + return slot, node, err + } + + if c.opt.RouteRandomly { + node := state.slotRandomNode(slot) + return slot, node, nil + } + + node, err := state.slotSlaveNode(slot) + return slot, node, err + } + + node, err := state.slotMasterNode(slot) + return slot, node, err +} + +func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { + state, err := c.state.Get() + if err != nil { + return nil, err + } + + nodes := state.slotNodes(slot) + if len(nodes) > 0 { + return nodes[0], nil + } + return c.nodes.Random() +} + +func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { + if len(keys) == 0 { + return fmt.Errorf("redis: Watch requires at least one key") + } + + slot := hashtag.Slot(keys[0]) + for _, key := range keys[1:] { + if hashtag.Slot(key) != slot { + err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") + return err + } + } + + node, err := c.slotMasterNode(slot) + if err != nil { + return err + } + + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) + } + + err = node.Client.Watch(fn, keys...) + if err == nil { + break + } + + if internal.IsRetryableError(err, true) { + c.state.LazyReload() + continue + } + + moved, ask, addr := internal.IsMovedError(err) + if moved || ask { + c.state.LazyReload() + node, err = c.nodes.GetOrCreate(addr) + if err != nil { + return err + } + continue + } + + if err == pool.ErrClosed { + node, err = c.slotMasterNode(slot) + if err != nil { + return err + } + continue + } + + return err + } + + return err +} + +// Close closes the cluster client, releasing any open resources. +// +// It is rare to Close a ClusterClient, as the ClusterClient is meant +// to be long-lived and shared between many goroutines. +func (c *ClusterClient) Close() error { + return c.nodes.Close() +} + +// Do creates a Cmd from the args and processes the cmd. 
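+//
+// A hedged usage sketch, assuming clusterClient was created with NewClusterClient:
+//
+//	v, err := clusterClient.Do("GET", "key").Result()
+//	if err == Nil {
+//		// key does not exist
+//	}
+//	_ = v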
+func (c *ClusterClient) Do(args ...interface{}) *Cmd { + cmd := NewCmd(args...) + c.Process(cmd) + return cmd +} + +func (c *ClusterClient) WrapProcess( + fn func(oldProcess func(Cmder) error) func(Cmder) error, +) { + c.process = fn(c.process) +} + +func (c *ClusterClient) Process(cmd Cmder) error { + return c.process(cmd) +} + +func (c *ClusterClient) defaultProcess(cmd Cmder) error { + var node *clusterNode + var ask bool + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) + } + + if node == nil { + var err error + _, node, err = c.cmdSlotAndNode(cmd) + if err != nil { + cmd.setErr(err) + break + } + } + + var err error + if ask { + pipe := node.Client.Pipeline() + _ = pipe.Process(NewCmd("ASKING")) + _ = pipe.Process(cmd) + _, err = pipe.Exec() + _ = pipe.Close() + ask = false + } else { + err = node.Client.Process(cmd) + } + + // If there is no error - we are done. + if err == nil { + break + } + + // If slave is loading - read from master. + if c.opt.ReadOnly && internal.IsLoadingError(err) { + node.MarkAsLoading() + continue + } + + if internal.IsRetryableError(err, true) { + c.state.LazyReload() + + // First retry the same node. + if attempt == 0 { + continue + } + + // Second try random node. + node, err = c.nodes.Random() + if err != nil { + break + } + continue + } + + var moved bool + var addr string + moved, ask, addr = internal.IsMovedError(err) + if moved || ask { + c.state.LazyReload() + + node, err = c.nodes.GetOrCreate(addr) + if err != nil { + break + } + continue + } + + if err == pool.ErrClosed { + node = nil + continue + } + + break + } + + return cmd.Err() +} + +// ForEachMaster concurrently calls the fn on each master node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { + state, err := c.state.ReloadOrGet() + if err != nil { + return err + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + for _, master := range state.Masters { + wg.Add(1) + go func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + }(master) + } + wg.Wait() + + select { + case err := <-errCh: + return err + default: + return nil + } +} + +// ForEachSlave concurrently calls the fn on each slave node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { + state, err := c.state.ReloadOrGet() + if err != nil { + return err + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + for _, slave := range state.Slaves { + wg.Add(1) + go func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + }(slave) + } + wg.Wait() + + select { + case err := <-errCh: + return err + default: + return nil + } +} + +// ForEachNode concurrently calls the fn on each known node in the cluster. +// It returns the first error if any. 
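+//
+// A hedged usage sketch, assuming clusterClient was created with NewClusterClient:
+//
+//	err := clusterClient.ForEachNode(func(client *Client) error {
+//		return client.Ping().Err()
+//	})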
+func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { + state, err := c.state.ReloadOrGet() + if err != nil { + return err + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + worker := func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + } + + for _, node := range state.Masters { + wg.Add(1) + go worker(node) + } + for _, node := range state.Slaves { + wg.Add(1) + go worker(node) + } + + wg.Wait() + select { + case err := <-errCh: + return err + default: + return nil + } +} + +// PoolStats returns accumulated connection pool stats. +func (c *ClusterClient) PoolStats() *PoolStats { + var acc PoolStats + + state, _ := c.state.Get() + if state == nil { + return &acc + } + + for _, node := range state.Masters { + s := node.Client.connPool.Stats() + acc.Hits += s.Hits + acc.Misses += s.Misses + acc.Timeouts += s.Timeouts + + acc.TotalConns += s.TotalConns + acc.IdleConns += s.IdleConns + acc.StaleConns += s.StaleConns + } + + for _, node := range state.Slaves { + s := node.Client.connPool.Stats() + acc.Hits += s.Hits + acc.Misses += s.Misses + acc.Timeouts += s.Timeouts + + acc.TotalConns += s.TotalConns + acc.IdleConns += s.IdleConns + acc.StaleConns += s.StaleConns + } + + return &acc +} + +func (c *ClusterClient) loadState() (*clusterState, error) { + if c.opt.ClusterSlots != nil { + slots, err := c.opt.ClusterSlots() + if err != nil { + return nil, err + } + return newClusterState(c.nodes, slots, "") + } + + addrs, err := c.nodes.Addrs() + if err != nil { + return nil, err + } + + var firstErr error + for _, addr := range addrs { + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + + slots, err := node.Client.ClusterSlots().Result() + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + + return newClusterState(c.nodes, slots, node.Client.opt.Addr) + } + + return nil, firstErr +} + +// reaper closes idle connections to the cluster. 
+func (c *ClusterClient) reaper(idleCheckFrequency time.Duration) { + ticker := time.NewTicker(idleCheckFrequency) + defer ticker.Stop() + + for range ticker.C { + nodes, err := c.nodes.All() + if err != nil { + break + } + + for _, node := range nodes { + _, err := node.Client.connPool.(*pool.ConnPool).ReapStaleConns() + if err != nil { + internal.Logf("ReapStaleConns failed: %s", err) + } + } + } +} + +func (c *ClusterClient) Pipeline() Pipeliner { + pipe := Pipeline{ + exec: c.processPipeline, + } + pipe.statefulCmdable.setProcessor(pipe.Process) + return &pipe +} + +func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.Pipeline().Pipelined(fn) +} + +func (c *ClusterClient) WrapProcessPipeline( + fn func(oldProcess func([]Cmder) error) func([]Cmder) error, +) { + c.processPipeline = fn(c.processPipeline) +} + +func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error { + cmdsMap := newCmdsMap() + err := c.mapCmdsByNode(cmds, cmdsMap) + if err != nil { + setCmdsErr(cmds, err) + return err + } + + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) + } + + failedCmds := newCmdsMap() + var wg sync.WaitGroup + + for node, cmds := range cmdsMap.m { + wg.Add(1) + go func(node *clusterNode, cmds []Cmder) { + defer wg.Done() + + cn, err := node.Client.getConn() + if err != nil { + if err == pool.ErrClosed { + c.mapCmdsByNode(cmds, failedCmds) + } else { + setCmdsErr(cmds, err) + } + return + } + + err = c.pipelineProcessCmds(node, cn, cmds, failedCmds) + node.Client.releaseConnStrict(cn, err) + }(node, cmds) + } + + wg.Wait() + if len(failedCmds.m) == 0 { + break + } + cmdsMap = failedCmds + } + + return cmdsFirstErr(cmds) +} + +type cmdsMap struct { + mu sync.Mutex + m map[*clusterNode][]Cmder +} + +func newCmdsMap() *cmdsMap { + return &cmdsMap{ + m: make(map[*clusterNode][]Cmder), + } +} + +func (c *ClusterClient) mapCmdsByNode(cmds []Cmder, cmdsMap *cmdsMap) error { + state, err := c.state.Get() + if err != nil { + setCmdsErr(cmds, err) + return err + } + + cmdsAreReadOnly := c.cmdsAreReadOnly(cmds) + for _, cmd := range cmds { + var node *clusterNode + var err error + if cmdsAreReadOnly { + _, node, err = c.cmdSlotAndNode(cmd) + } else { + slot := c.cmdSlot(cmd) + node, err = state.slotMasterNode(slot) + } + if err != nil { + return err + } + cmdsMap.mu.Lock() + cmdsMap.m[node] = append(cmdsMap.m[node], cmd) + cmdsMap.mu.Unlock() + } + return nil +} + +func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { + for _, cmd := range cmds { + cmdInfo := c.cmdInfo(cmd.Name()) + if cmdInfo == nil || !cmdInfo.ReadOnly { + return false + } + } + return true +} + +func (c *ClusterClient) pipelineProcessCmds( + node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, +) error { + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmds...) 
+ }) + if err != nil { + setCmdsErr(cmds, err) + failedCmds.mu.Lock() + failedCmds.m[node] = cmds + failedCmds.mu.Unlock() + return err + } + + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(rd, cmds, failedCmds) + }) + return err +} + +func (c *ClusterClient) pipelineReadCmds( + rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, +) error { + for _, cmd := range cmds { + err := cmd.readReply(rd) + if err == nil { + continue + } + + if c.checkMovedErr(cmd, err, failedCmds) { + continue + } + + if internal.IsRedisError(err) { + continue + } + + return err + } + return nil +} + +func (c *ClusterClient) checkMovedErr( + cmd Cmder, err error, failedCmds *cmdsMap, +) bool { + moved, ask, addr := internal.IsMovedError(err) + + if moved { + c.state.LazyReload() + + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + return false + } + + failedCmds.mu.Lock() + failedCmds.m[node] = append(failedCmds.m[node], cmd) + failedCmds.mu.Unlock() + return true + } + + if ask { + node, err := c.nodes.GetOrCreate(addr) + if err != nil { + return false + } + + failedCmds.mu.Lock() + failedCmds.m[node] = append(failedCmds.m[node], NewCmd("ASKING"), cmd) + failedCmds.mu.Unlock() + return true + } + + return false +} + +// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. +func (c *ClusterClient) TxPipeline() Pipeliner { + pipe := Pipeline{ + exec: c.processTxPipeline, + } + pipe.statefulCmdable.setProcessor(pipe.Process) + return &pipe +} + +func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.TxPipeline().Pipelined(fn) +} + +func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error { + state, err := c.state.Get() + if err != nil { + return err + } + + cmdsMap := c.mapCmdsBySlot(cmds) + for slot, cmds := range cmdsMap { + node, err := state.slotMasterNode(slot) + if err != nil { + setCmdsErr(cmds, err) + continue + } + cmdsMap := map[*clusterNode][]Cmder{node: cmds} + + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) + } + + failedCmds := newCmdsMap() + var wg sync.WaitGroup + + for node, cmds := range cmdsMap { + wg.Add(1) + go func(node *clusterNode, cmds []Cmder) { + defer wg.Done() + + cn, err := node.Client.getConn() + if err != nil { + if err == pool.ErrClosed { + c.mapCmdsByNode(cmds, failedCmds) + } else { + setCmdsErr(cmds, err) + } + return + } + + err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) + node.Client.releaseConnStrict(cn, err) + }(node, cmds) + } + + wg.Wait() + if len(failedCmds.m) == 0 { + break + } + cmdsMap = failedCmds.m + } + } + + return cmdsFirstErr(cmds) +} + +func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { + cmdsMap := make(map[int][]Cmder) + for _, cmd := range cmds { + slot := c.cmdSlot(cmd) + cmdsMap[slot] = append(cmdsMap[slot], cmd) + } + return cmdsMap +} + +func (c *ClusterClient) txPipelineProcessCmds( + node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap, +) error { + err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) + }) + if err != nil { + setCmdsErr(cmds, err) + failedCmds.mu.Lock() + failedCmds.m[node] = cmds + failedCmds.mu.Unlock() + return err + } + + err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error { + err := c.txPipelineReadQueued(rd, cmds, failedCmds) + if err != nil { + setCmdsErr(cmds, err) + return err + } + return pipelineReadCmds(rd, 
cmds) + }) + return err +} + +func (c *ClusterClient) txPipelineReadQueued( + rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, +) error { + // Parse queued replies. + var statusCmd StatusCmd + if err := statusCmd.readReply(rd); err != nil { + return err + } + + for _, cmd := range cmds { + err := statusCmd.readReply(rd) + if err == nil { + continue + } + + if c.checkMovedErr(cmd, err, failedCmds) || internal.IsRedisError(err) { + continue + } + + return err + } + + // Parse number of replies. + line, err := rd.ReadLine() + if err != nil { + if err == Nil { + err = TxFailedErr + } + return err + } + + switch line[0] { + case proto.ErrorReply: + err := proto.ParseErrorReply(line) + for _, cmd := range cmds { + if !c.checkMovedErr(cmd, err, failedCmds) { + break + } + } + return err + case proto.ArrayReply: + // ok + default: + err := fmt.Errorf("redis: expected '*', but got line %q", line) + return err + } + + return nil +} + +func (c *ClusterClient) pubSub(channels []string) *PubSub { + var node *clusterNode + pubsub := &PubSub{ + opt: c.opt.clientOptions(), + + newConn: func(channels []string) (*pool.Conn, error) { + if node == nil { + var slot int + if len(channels) > 0 { + slot = hashtag.Slot(channels[0]) + } else { + slot = -1 + } + + masterNode, err := c.slotMasterNode(slot) + if err != nil { + return nil, err + } + node = masterNode + } + return node.Client.newConn() + }, + closeConn: func(cn *pool.Conn) error { + return node.Client.connPool.CloseConn(cn) + }, + } + pubsub.init() + return pubsub +} + +// Subscribe subscribes the client to the specified channels. +// Channels can be omitted to create empty subscription. +func (c *ClusterClient) Subscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.Subscribe(channels...) + } + return pubsub +} + +// PSubscribe subscribes the client to the given patterns. +// Patterns can be omitted to create empty subscription. +func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.PSubscribe(channels...) + } + return pubsub +} + +func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { + for _, n := range nodes { + if n == node { + return nodes + } + } + return append(nodes, node) +} + +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} + +func remove(ss []string, es ...string) []string { + if len(es) == 0 { + return ss[:0] + } + for _, e := range es { + for i, s := range ss { + if s == e { + ss = append(ss[:i], ss[i+1:]...) 
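+				// first occurrence of e removed; move on to the next element of es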
+ break + } + } + } + return ss +} diff --git a/vendor/github.com/go-redis/redis/cluster_commands.go b/vendor/github.com/go-redis/redis/cluster_commands.go new file mode 100644 index 00000000..dff62c90 --- /dev/null +++ b/vendor/github.com/go-redis/redis/cluster_commands.go @@ -0,0 +1,22 @@ +package redis + +import "sync/atomic" + +func (c *ClusterClient) DBSize() *IntCmd { + cmd := NewIntCmd("dbsize") + var size int64 + err := c.ForEachMaster(func(master *Client) error { + n, err := master.DBSize().Result() + if err != nil { + return err + } + atomic.AddInt64(&size, n) + return nil + }) + if err != nil { + cmd.setErr(err) + return cmd + } + cmd.val = size + return cmd +} diff --git a/vendor/github.com/go-redis/redis/command.go b/vendor/github.com/go-redis/redis/command.go new file mode 100644 index 00000000..05dd6755 --- /dev/null +++ b/vendor/github.com/go-redis/redis/command.go @@ -0,0 +1,1874 @@ +package redis + +import ( + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/go-redis/redis/internal" + "github.com/go-redis/redis/internal/proto" +) + +type Cmder interface { + Name() string + Args() []interface{} + stringArg(int) string + + readReply(rd *proto.Reader) error + setErr(error) + + readTimeout() *time.Duration + + Err() error +} + +func setCmdsErr(cmds []Cmder, e error) { + for _, cmd := range cmds { + if cmd.Err() == nil { + cmd.setErr(e) + } + } +} + +func cmdsFirstErr(cmds []Cmder) error { + for _, cmd := range cmds { + if err := cmd.Err(); err != nil { + return err + } + } + return nil +} + +func writeCmd(wr *proto.Writer, cmds ...Cmder) error { + for _, cmd := range cmds { + err := wr.WriteArgs(cmd.Args()) + if err != nil { + return err + } + } + return nil +} + +func cmdString(cmd Cmder, val interface{}) string { + var ss []string + for _, arg := range cmd.Args() { + ss = append(ss, fmt.Sprint(arg)) + } + s := strings.Join(ss, " ") + if err := cmd.Err(); err != nil { + return s + ": " + err.Error() + } + if val != nil { + switch vv := val.(type) { + case []byte: + return s + ": " + string(vv) + default: + return s + ": " + fmt.Sprint(val) + } + } + return s + +} + +func cmdFirstKeyPos(cmd Cmder, info *CommandInfo) int { + switch cmd.Name() { + case "eval", "evalsha": + if cmd.stringArg(2) != "0" { + return 3 + } + + return 0 + case "publish": + return 1 + } + if info == nil { + return 0 + } + return int(info.FirstKeyPos) +} + +//------------------------------------------------------------------------------ + +type baseCmd struct { + _args []interface{} + err error + + _readTimeout *time.Duration +} + +var _ Cmder = (*Cmd)(nil) + +func (cmd *baseCmd) Err() error { + return cmd.err +} + +func (cmd *baseCmd) Args() []interface{} { + return cmd._args +} + +func (cmd *baseCmd) stringArg(pos int) string { + if pos < 0 || pos >= len(cmd._args) { + return "" + } + s, _ := cmd._args[pos].(string) + return s +} + +func (cmd *baseCmd) Name() string { + if len(cmd._args) > 0 { + // Cmd name must be lower cased. 
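+		// (comparisons elsewhere assume lower case, e.g. the "eval" and
+		// "publish" cases in cmdFirstKeyPos and the COMMAND info map keys)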
+ s := internal.ToLower(cmd.stringArg(0)) + cmd._args[0] = s + return s + } + return "" +} + +func (cmd *baseCmd) readTimeout() *time.Duration { + return cmd._readTimeout +} + +func (cmd *baseCmd) setReadTimeout(d time.Duration) { + cmd._readTimeout = &d +} + +func (cmd *baseCmd) setErr(e error) { + cmd.err = e +} + +//------------------------------------------------------------------------------ + +type Cmd struct { + baseCmd + + val interface{} +} + +func NewCmd(args ...interface{}) *Cmd { + return &Cmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *Cmd) Val() interface{} { + return cmd.val +} + +func (cmd *Cmd) Result() (interface{}, error) { + return cmd.val, cmd.err +} + +func (cmd *Cmd) String() (string, error) { + if cmd.err != nil { + return "", cmd.err + } + switch val := cmd.val.(type) { + case string: + return val, nil + default: + err := fmt.Errorf("redis: unexpected type=%T for String", val) + return "", err + } +} + +func (cmd *Cmd) Int() (int, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return int(val), nil + case string: + return strconv.Atoi(val) + default: + err := fmt.Errorf("redis: unexpected type=%T for Int", val) + return 0, err + } +} + +func (cmd *Cmd) Int64() (int64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return val, nil + case string: + return strconv.ParseInt(val, 10, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Int64", val) + return 0, err + } +} + +func (cmd *Cmd) Uint64() (uint64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return uint64(val), nil + case string: + return strconv.ParseUint(val, 10, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Uint64", val) + return 0, err + } +} + +func (cmd *Cmd) Float64() (float64, error) { + if cmd.err != nil { + return 0, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return float64(val), nil + case string: + return strconv.ParseFloat(val, 64) + default: + err := fmt.Errorf("redis: unexpected type=%T for Float64", val) + return 0, err + } +} + +func (cmd *Cmd) Bool() (bool, error) { + if cmd.err != nil { + return false, cmd.err + } + switch val := cmd.val.(type) { + case int64: + return val != 0, nil + case string: + return strconv.ParseBool(val) + default: + err := fmt.Errorf("redis: unexpected type=%T for Bool", val) + return false, err + } +} + +func (cmd *Cmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadReply(sliceParser) + return cmd.err +} + +// Implements proto.MultiBulkParse +func sliceParser(rd *proto.Reader, n int64) (interface{}, error) { + vals := make([]interface{}, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(sliceParser) + if err != nil { + if err == Nil { + vals = append(vals, nil) + continue + } + if err, ok := err.(proto.RedisError); ok { + vals = append(vals, err) + continue + } + return nil, err + } + + switch v := v.(type) { + case string: + vals = append(vals, v) + default: + vals = append(vals, v) + } + } + return vals, nil +} + +//------------------------------------------------------------------------------ + +type SliceCmd struct { + baseCmd + + val []interface{} +} + +var _ Cmder = (*SliceCmd)(nil) + +func NewSliceCmd(args ...interface{}) *SliceCmd { + return &SliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *SliceCmd) Val() []interface{} { + return cmd.val +} + +func (cmd *SliceCmd) Result() 
([]interface{}, error) { + return cmd.val, cmd.err +} + +func (cmd *SliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *SliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(sliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]interface{}) + return nil +} + +//------------------------------------------------------------------------------ + +type StatusCmd struct { + baseCmd + + val string +} + +var _ Cmder = (*StatusCmd)(nil) + +func NewStatusCmd(args ...interface{}) *StatusCmd { + return &StatusCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StatusCmd) Val() string { + return cmd.val +} + +func (cmd *StatusCmd) Result() (string, error) { + return cmd.val, cmd.err +} + +func (cmd *StatusCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StatusCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadString() + return cmd.err +} + +//------------------------------------------------------------------------------ + +type IntCmd struct { + baseCmd + + val int64 +} + +var _ Cmder = (*IntCmd)(nil) + +func NewIntCmd(args ...interface{}) *IntCmd { + return &IntCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *IntCmd) Val() int64 { + return cmd.val +} + +func (cmd *IntCmd) Result() (int64, error) { + return cmd.val, cmd.err +} + +func (cmd *IntCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *IntCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadIntReply() + return cmd.err +} + +//------------------------------------------------------------------------------ + +type DurationCmd struct { + baseCmd + + val time.Duration + precision time.Duration +} + +var _ Cmder = (*DurationCmd)(nil) + +func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd { + return &DurationCmd{ + baseCmd: baseCmd{_args: args}, + precision: precision, + } +} + +func (cmd *DurationCmd) Val() time.Duration { + return cmd.val +} + +func (cmd *DurationCmd) Result() (time.Duration, error) { + return cmd.val, cmd.err +} + +func (cmd *DurationCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *DurationCmd) readReply(rd *proto.Reader) error { + var n int64 + n, cmd.err = rd.ReadIntReply() + if cmd.err != nil { + return cmd.err + } + cmd.val = time.Duration(n) * cmd.precision + return nil +} + +//------------------------------------------------------------------------------ + +type TimeCmd struct { + baseCmd + + val time.Time +} + +var _ Cmder = (*TimeCmd)(nil) + +func NewTimeCmd(args ...interface{}) *TimeCmd { + return &TimeCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *TimeCmd) Val() time.Time { + return cmd.val +} + +func (cmd *TimeCmd) Result() (time.Time, error) { + return cmd.val, cmd.err +} + +func (cmd *TimeCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *TimeCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(timeParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.(time.Time) + return nil +} + +// Implements proto.MultiBulkParse +func timeParser(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d elements, expected 2", n) + } + + sec, err := rd.ReadInt() + if err != nil { + return nil, err + } + + microsec, err := rd.ReadInt() + if err != nil { + return nil, err + } + + return time.Unix(sec, microsec*1000), nil +} + 
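+// The TIME reply is a two-element array: Unix seconds and microseconds.
+// timeParser converts it with time.Unix(sec, microsec*1000), since time.Unix
+// takes nanoseconds as its second argument. A minimal usage sketch ("rdb"
+// stands in for any configured *redis.Client):
+//
+//	if t, err := rdb.Time().Result(); err == nil {
+//		fmt.Println("server clock:", t)
+//	}
+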
+//------------------------------------------------------------------------------ + +type BoolCmd struct { + baseCmd + + val bool +} + +var _ Cmder = (*BoolCmd)(nil) + +func NewBoolCmd(args ...interface{}) *BoolCmd { + return &BoolCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *BoolCmd) Val() bool { + return cmd.val +} + +func (cmd *BoolCmd) Result() (bool, error) { + return cmd.val, cmd.err +} + +func (cmd *BoolCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *BoolCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadReply(nil) + // `SET key value NX` returns nil when key already exists. But + // `SETNX key value` returns bool (0/1). So convert nil to bool. + // TODO: is this okay? + if cmd.err == Nil { + cmd.val = false + cmd.err = nil + return nil + } + if cmd.err != nil { + return cmd.err + } + switch v := v.(type) { + case int64: + cmd.val = v == 1 + return nil + case string: + cmd.val = v == "OK" + return nil + default: + cmd.err = fmt.Errorf("got %T, wanted int64 or string", v) + return cmd.err + } +} + +//------------------------------------------------------------------------------ + +type StringCmd struct { + baseCmd + + val string +} + +var _ Cmder = (*StringCmd)(nil) + +func NewStringCmd(args ...interface{}) *StringCmd { + return &StringCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StringCmd) Val() string { + return cmd.val +} + +func (cmd *StringCmd) Result() (string, error) { + return cmd.Val(), cmd.err +} + +func (cmd *StringCmd) Bytes() ([]byte, error) { + return []byte(cmd.val), cmd.err +} + +func (cmd *StringCmd) Int() (int, error) { + if cmd.err != nil { + return 0, cmd.err + } + return strconv.Atoi(cmd.Val()) +} + +func (cmd *StringCmd) Int64() (int64, error) { + if cmd.err != nil { + return 0, cmd.err + } + return strconv.ParseInt(cmd.Val(), 10, 64) +} + +func (cmd *StringCmd) Uint64() (uint64, error) { + if cmd.err != nil { + return 0, cmd.err + } + return strconv.ParseUint(cmd.Val(), 10, 64) +} + +func (cmd *StringCmd) Float64() (float64, error) { + if cmd.err != nil { + return 0, cmd.err + } + return strconv.ParseFloat(cmd.Val(), 64) +} + +func (cmd *StringCmd) Scan(val interface{}) error { + if cmd.err != nil { + return cmd.err + } + return proto.Scan([]byte(cmd.val), val) +} + +func (cmd *StringCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadString() + return cmd.err +} + +//------------------------------------------------------------------------------ + +type FloatCmd struct { + baseCmd + + val float64 +} + +var _ Cmder = (*FloatCmd)(nil) + +func NewFloatCmd(args ...interface{}) *FloatCmd { + return &FloatCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *FloatCmd) Val() float64 { + return cmd.val +} + +func (cmd *FloatCmd) Result() (float64, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *FloatCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *FloatCmd) readReply(rd *proto.Reader) error { + cmd.val, cmd.err = rd.ReadFloatReply() + return cmd.err +} + +//------------------------------------------------------------------------------ + +type StringSliceCmd struct { + baseCmd + + val []string +} + +var _ Cmder = (*StringSliceCmd)(nil) + +func NewStringSliceCmd(args ...interface{}) *StringSliceCmd { + return &StringSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StringSliceCmd) Val() []string { + return cmd.val +} + +func (cmd 
*StringSliceCmd) Result() ([]string, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *StringSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { + return proto.ScanSlice(cmd.Val(), container) +} + +func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(stringSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]string) + return nil +} + +// Implements proto.MultiBulkParse +func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ss := make([]string, 0, n) + for i := int64(0); i < n; i++ { + s, err := rd.ReadString() + if err == Nil { + ss = append(ss, "") + } else if err != nil { + return nil, err + } else { + ss = append(ss, s) + } + } + return ss, nil +} + +//------------------------------------------------------------------------------ + +type BoolSliceCmd struct { + baseCmd + + val []bool +} + +var _ Cmder = (*BoolSliceCmd)(nil) + +func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd { + return &BoolSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *BoolSliceCmd) Val() []bool { + return cmd.val +} + +func (cmd *BoolSliceCmd) Result() ([]bool, error) { + return cmd.val, cmd.err +} + +func (cmd *BoolSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(boolSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]bool) + return nil +} + +// Implements proto.MultiBulkParse +func boolSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + bools := make([]bool, 0, n) + for i := int64(0); i < n; i++ { + n, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + bools = append(bools, n == 1) + } + return bools, nil +} + +//------------------------------------------------------------------------------ + +type StringStringMapCmd struct { + baseCmd + + val map[string]string +} + +var _ Cmder = (*StringStringMapCmd)(nil) + +func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd { + return &StringStringMapCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StringStringMapCmd) Val() map[string]string { + return cmd.val +} + +func (cmd *StringStringMapCmd) Result() (map[string]string, error) { + return cmd.val, cmd.err +} + +func (cmd *StringStringMapCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(stringStringMapParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.(map[string]string) + return nil +} + +// Implements proto.MultiBulkParse +func stringStringMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]string, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + value, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = value + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type StringIntMapCmd struct { + baseCmd + + val map[string]int64 +} + +var _ Cmder = (*StringIntMapCmd)(nil) + +func NewStringIntMapCmd(args ...interface{}) *StringIntMapCmd { + return &StringIntMapCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StringIntMapCmd) Val() map[string]int64 { + return cmd.val +} + +func (cmd 
*StringIntMapCmd) Result() (map[string]int64, error) { + return cmd.val, cmd.err +} + +func (cmd *StringIntMapCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(stringIntMapParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.(map[string]int64) + return nil +} + +// Implements proto.MultiBulkParse +func stringIntMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]int64, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + n, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + m[key] = n + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type StringStructMapCmd struct { + baseCmd + + val map[string]struct{} +} + +var _ Cmder = (*StringStructMapCmd)(nil) + +func NewStringStructMapCmd(args ...interface{}) *StringStructMapCmd { + return &StringStructMapCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *StringStructMapCmd) Val() map[string]struct{} { + return cmd.val +} + +func (cmd *StringStructMapCmd) Result() (map[string]struct{}, error) { + return cmd.val, cmd.err +} + +func (cmd *StringStructMapCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(stringStructMapParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.(map[string]struct{}) + return nil +} + +// Implements proto.MultiBulkParse +func stringStructMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]struct{}, n) + for i := int64(0); i < n; i++ { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = struct{}{} + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type XMessage struct { + ID string + Values map[string]interface{} +} + +type XMessageSliceCmd struct { + baseCmd + + val []XMessage +} + +var _ Cmder = (*XMessageSliceCmd)(nil) + +func NewXMessageSliceCmd(args ...interface{}) *XMessageSliceCmd { + return &XMessageSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XMessageSliceCmd) Val() []XMessage { + return cmd.val +} + +func (cmd *XMessageSliceCmd) Result() ([]XMessage, error) { + return cmd.val, cmd.err +} + +func (cmd *XMessageSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(xMessageSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]XMessage) + return nil +} + +// Implements proto.MultiBulkParse +func xMessageSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + msgs := make([]XMessage, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + id, err := rd.ReadString() + if err != nil { + return nil, err + } + + v, err := rd.ReadArrayReply(stringInterfaceMapParser) + if err != nil { + return nil, err + } + + msgs = append(msgs, XMessage{ + ID: id, + Values: v.(map[string]interface{}), + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return msgs, nil +} + +// Implements proto.MultiBulkParse +func stringInterfaceMapParser(rd *proto.Reader, n int64) (interface{}, error) { + m := 
make(map[string]interface{}, n/2) + for i := int64(0); i < n; i += 2 { + key, err := rd.ReadString() + if err != nil { + return nil, err + } + + value, err := rd.ReadString() + if err != nil { + return nil, err + } + + m[key] = value + } + return m, nil +} + +//------------------------------------------------------------------------------ + +type XStream struct { + Stream string + Messages []XMessage +} + +type XStreamSliceCmd struct { + baseCmd + + val []XStream +} + +var _ Cmder = (*XStreamSliceCmd)(nil) + +func NewXStreamSliceCmd(args ...interface{}) *XStreamSliceCmd { + return &XStreamSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XStreamSliceCmd) Val() []XStream { + return cmd.val +} + +func (cmd *XStreamSliceCmd) Result() ([]XStream, error) { + return cmd.val, cmd.err +} + +func (cmd *XStreamSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(xStreamSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]XStream) + return nil +} + +// Implements proto.MultiBulkParse +func xStreamSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ret := make([]XStream, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d, wanted 2", n) + } + + stream, err := rd.ReadString() + if err != nil { + return nil, err + } + + v, err := rd.ReadArrayReply(xMessageSliceParser) + if err != nil { + return nil, err + } + + ret = append(ret, XStream{ + Stream: stream, + Messages: v.([]XMessage), + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return ret, nil +} + +//------------------------------------------------------------------------------ + +type XPending struct { + Count int64 + Lower string + Higher string + Consumers map[string]int64 +} + +type XPendingCmd struct { + baseCmd + val *XPending +} + +var _ Cmder = (*XPendingCmd)(nil) + +func NewXPendingCmd(args ...interface{}) *XPendingCmd { + return &XPendingCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XPendingCmd) Val() *XPending { + return cmd.val +} + +func (cmd *XPendingCmd) Result() (*XPending, error) { + return cmd.val, cmd.err +} + +func (cmd *XPendingCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { + var info interface{} + info, cmd.err = rd.ReadArrayReply(xPendingParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = info.(*XPending) + return nil +} + +func xPendingParser(rd *proto.Reader, n int64) (interface{}, error) { + if n != 4 { + return nil, fmt.Errorf("got %d, wanted 4", n) + } + + count, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + lower, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + higher, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + pending := &XPending{ + Count: count, + Lower: lower, + Higher: higher, + } + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + for i := int64(0); i < n; i++ { + _, err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 2 { + return nil, fmt.Errorf("got %d, wanted 2", n) + } + + consumerName, err := rd.ReadString() + if err != nil { + return nil, err + } + + consumerPending, err := rd.ReadInt() + if err != nil { + return nil, err + } + + if 
pending.Consumers == nil { + pending.Consumers = make(map[string]int64) + } + pending.Consumers[consumerName] = consumerPending + + return nil, nil + }) + if err != nil { + return nil, err + } + } + return nil, nil + }) + if err != nil && err != Nil { + return nil, err + } + + return pending, nil +} + +//------------------------------------------------------------------------------ + +type XPendingExt struct { + Id string + Consumer string + Idle time.Duration + RetryCount int64 +} + +type XPendingExtCmd struct { + baseCmd + val []XPendingExt +} + +var _ Cmder = (*XPendingExtCmd)(nil) + +func NewXPendingExtCmd(args ...interface{}) *XPendingExtCmd { + return &XPendingExtCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *XPendingExtCmd) Val() []XPendingExt { + return cmd.val +} + +func (cmd *XPendingExtCmd) Result() ([]XPendingExt, error) { + return cmd.val, cmd.err +} + +func (cmd *XPendingExtCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { + var info interface{} + info, cmd.err = rd.ReadArrayReply(xPendingExtSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = info.([]XPendingExt) + return nil +} + +func xPendingExtSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + ret := make([]XPendingExt, 0, n) + for i := int64(0); i < n; i++ { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + if n != 4 { + return nil, fmt.Errorf("got %d, wanted 4", n) + } + + id, err := rd.ReadString() + if err != nil { + return nil, err + } + + consumer, err := rd.ReadString() + if err != nil && err != Nil { + return nil, err + } + + idle, err := rd.ReadIntReply() + if err != nil && err != Nil { + return nil, err + } + + retryCount, err := rd.ReadIntReply() + if err != nil && err != Nil { + return nil, err + } + + ret = append(ret, XPendingExt{ + Id: id, + Consumer: consumer, + Idle: time.Duration(idle) * time.Millisecond, + RetryCount: retryCount, + }) + return nil, nil + }) + if err != nil { + return nil, err + } + } + return ret, nil +} + +//------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------ + +type ZSliceCmd struct { + baseCmd + + val []Z +} + +var _ Cmder = (*ZSliceCmd)(nil) + +func NewZSliceCmd(args ...interface{}) *ZSliceCmd { + return &ZSliceCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *ZSliceCmd) Val() []Z { + return cmd.val +} + +func (cmd *ZSliceCmd) Result() ([]Z, error) { + return cmd.val, cmd.err +} + +func (cmd *ZSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(zSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]Z) + return nil +} + +// Implements proto.MultiBulkParse +func zSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + zz := make([]Z, n/2) + for i := int64(0); i < n; i += 2 { + var err error + + z := &zz[i/2] + + z.Member, err = rd.ReadString() + if err != nil { + return nil, err + } + + z.Score, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + return zz, nil +} + +//------------------------------------------------------------------------------ + +type ScanCmd struct { + baseCmd + + page []string + cursor uint64 + + process func(cmd Cmder) error +} + +var _ Cmder = (*ScanCmd)(nil) + +func NewScanCmd(process func(cmd Cmder) error, args 
...interface{}) *ScanCmd { + return &ScanCmd{ + baseCmd: baseCmd{_args: args}, + process: process, + } +} + +func (cmd *ScanCmd) Val() (keys []string, cursor uint64) { + return cmd.page, cmd.cursor +} + +func (cmd *ScanCmd) Result() (keys []string, cursor uint64, err error) { + return cmd.page, cmd.cursor, cmd.err +} + +func (cmd *ScanCmd) String() string { + return cmdString(cmd, cmd.page) +} + +func (cmd *ScanCmd) readReply(rd *proto.Reader) error { + cmd.page, cmd.cursor, cmd.err = rd.ReadScanReply() + return cmd.err +} + +// Iterator creates a new ScanIterator. +func (cmd *ScanCmd) Iterator() *ScanIterator { + return &ScanIterator{ + cmd: cmd, + } +} + +//------------------------------------------------------------------------------ + +type ClusterNode struct { + Id string + Addr string +} + +type ClusterSlot struct { + Start int + End int + Nodes []ClusterNode +} + +type ClusterSlotsCmd struct { + baseCmd + + val []ClusterSlot +} + +var _ Cmder = (*ClusterSlotsCmd)(nil) + +func NewClusterSlotsCmd(args ...interface{}) *ClusterSlotsCmd { + return &ClusterSlotsCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *ClusterSlotsCmd) Val() []ClusterSlot { + return cmd.val +} + +func (cmd *ClusterSlotsCmd) Result() ([]ClusterSlot, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *ClusterSlotsCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(clusterSlotsParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.([]ClusterSlot) + return nil +} + +// Implements proto.MultiBulkParse +func clusterSlotsParser(rd *proto.Reader, n int64) (interface{}, error) { + slots := make([]ClusterSlot, n) + for i := 0; i < len(slots); i++ { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n < 2 { + err := fmt.Errorf("redis: got %d elements in cluster info, expected at least 2", n) + return nil, err + } + + start, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + end, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + + nodes := make([]ClusterNode, n-2) + for j := 0; j < len(nodes); j++ { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n != 2 && n != 3 { + err := fmt.Errorf("got %d elements in cluster info address, expected 2 or 3", n) + return nil, err + } + + ip, err := rd.ReadString() + if err != nil { + return nil, err + } + + port, err := rd.ReadString() + if err != nil { + return nil, err + } + + nodes[j].Addr = net.JoinHostPort(ip, port) + + if n == 3 { + id, err := rd.ReadString() + if err != nil { + return nil, err + } + nodes[j].Id = id + } + } + + slots[i] = ClusterSlot{ + Start: int(start), + End: int(end), + Nodes: nodes, + } + } + return slots, nil +} + +//------------------------------------------------------------------------------ + +// GeoLocation is used with GeoAdd to add geospatial location. +type GeoLocation struct { + Name string + Longitude, Latitude, Dist float64 + GeoHash int64 +} + +// GeoRadiusQuery is used with GeoRadius to query geospatial index. +type GeoRadiusQuery struct { + Radius float64 + // Can be m, km, ft, or mi. Default is km. + Unit string + WithCoord bool + WithDist bool + WithGeoHash bool + Count int + // Can be ASC or DESC. Default is no sort order. 
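+	// The value is appended verbatim to the command arguments
+	// (see NewGeoLocationCmd below).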
+ Sort string + Store string + StoreDist string +} + +type GeoLocationCmd struct { + baseCmd + + q *GeoRadiusQuery + locations []GeoLocation +} + +var _ Cmder = (*GeoLocationCmd)(nil) + +func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { + args = append(args, q.Radius) + if q.Unit != "" { + args = append(args, q.Unit) + } else { + args = append(args, "km") + } + if q.WithCoord { + args = append(args, "withcoord") + } + if q.WithDist { + args = append(args, "withdist") + } + if q.WithGeoHash { + args = append(args, "withhash") + } + if q.Count > 0 { + args = append(args, "count", q.Count) + } + if q.Sort != "" { + args = append(args, q.Sort) + } + if q.Store != "" { + args = append(args, "store") + args = append(args, q.Store) + } + if q.StoreDist != "" { + args = append(args, "storedist") + args = append(args, q.StoreDist) + } + return &GeoLocationCmd{ + baseCmd: baseCmd{_args: args}, + q: q, + } +} + +func (cmd *GeoLocationCmd) Val() []GeoLocation { + return cmd.locations +} + +func (cmd *GeoLocationCmd) Result() ([]GeoLocation, error) { + return cmd.locations, cmd.err +} + +func (cmd *GeoLocationCmd) String() string { + return cmdString(cmd, cmd.locations) +} + +func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + if cmd.err != nil { + return cmd.err + } + cmd.locations = v.([]GeoLocation) + return nil +} + +func newGeoLocationParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { + var loc GeoLocation + var err error + + loc.Name, err = rd.ReadString() + if err != nil { + return nil, err + } + if q.WithDist { + loc.Dist, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + if q.WithGeoHash { + loc.GeoHash, err = rd.ReadIntReply() + if err != nil { + return nil, err + } + } + if q.WithCoord { + n, err := rd.ReadArrayLen() + if err != nil { + return nil, err + } + if n != 2 { + return nil, fmt.Errorf("got %d coordinates, expected 2", n) + } + + loc.Longitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + loc.Latitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + } + + return &loc, nil + } +} + +func newGeoLocationSliceParser(q *GeoRadiusQuery) proto.MultiBulkParse { + return func(rd *proto.Reader, n int64) (interface{}, error) { + locs := make([]GeoLocation, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(newGeoLocationParser(q)) + if err != nil { + return nil, err + } + switch vv := v.(type) { + case string: + locs = append(locs, GeoLocation{ + Name: vv, + }) + case *GeoLocation: + locs = append(locs, *vv) + default: + return nil, fmt.Errorf("got %T, expected string or *GeoLocation", v) + } + } + return locs, nil + } +} + +//------------------------------------------------------------------------------ + +type GeoPos struct { + Longitude, Latitude float64 +} + +type GeoPosCmd struct { + baseCmd + + positions []*GeoPos +} + +var _ Cmder = (*GeoPosCmd)(nil) + +func NewGeoPosCmd(args ...interface{}) *GeoPosCmd { + return &GeoPosCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *GeoPosCmd) Val() []*GeoPos { + return cmd.positions +} + +func (cmd *GeoPosCmd) Result() ([]*GeoPos, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *GeoPosCmd) String() string { + return cmdString(cmd, cmd.positions) +} + +func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = 
rd.ReadArrayReply(geoPosSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.positions = v.([]*GeoPos) + return nil +} + +func geoPosSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + positions := make([]*GeoPos, 0, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(geoPosParser) + if err != nil { + if err == Nil { + positions = append(positions, nil) + continue + } + return nil, err + } + switch v := v.(type) { + case *GeoPos: + positions = append(positions, v) + default: + return nil, fmt.Errorf("got %T, expected *GeoPos", v) + } + } + return positions, nil +} + +func geoPosParser(rd *proto.Reader, n int64) (interface{}, error) { + var pos GeoPos + var err error + + pos.Longitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + + pos.Latitude, err = rd.ReadFloatReply() + if err != nil { + return nil, err + } + + return &pos, nil +} + +//------------------------------------------------------------------------------ + +type CommandInfo struct { + Name string + Arity int8 + Flags []string + FirstKeyPos int8 + LastKeyPos int8 + StepCount int8 + ReadOnly bool +} + +type CommandsInfoCmd struct { + baseCmd + + val map[string]*CommandInfo +} + +var _ Cmder = (*CommandsInfoCmd)(nil) + +func NewCommandsInfoCmd(args ...interface{}) *CommandsInfoCmd { + return &CommandsInfoCmd{ + baseCmd: baseCmd{_args: args}, + } +} + +func (cmd *CommandsInfoCmd) Val() map[string]*CommandInfo { + return cmd.val +} + +func (cmd *CommandsInfoCmd) Result() (map[string]*CommandInfo, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *CommandsInfoCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { + var v interface{} + v, cmd.err = rd.ReadArrayReply(commandInfoSliceParser) + if cmd.err != nil { + return cmd.err + } + cmd.val = v.(map[string]*CommandInfo) + return nil +} + +// Implements proto.MultiBulkParse +func commandInfoSliceParser(rd *proto.Reader, n int64) (interface{}, error) { + m := make(map[string]*CommandInfo, n) + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(commandInfoParser) + if err != nil { + return nil, err + } + vv := v.(*CommandInfo) + m[vv.Name] = vv + + } + return m, nil +} + +func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { + var cmd CommandInfo + var err error + + if n != 6 { + return nil, fmt.Errorf("redis: got %d elements in COMMAND reply, wanted 6", n) + } + + cmd.Name, err = rd.ReadString() + if err != nil { + return nil, err + } + + arity, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.Arity = int8(arity) + + flags, err := rd.ReadReply(stringSliceParser) + if err != nil { + return nil, err + } + cmd.Flags = flags.([]string) + + firstKeyPos, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.FirstKeyPos = int8(firstKeyPos) + + lastKeyPos, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.LastKeyPos = int8(lastKeyPos) + + stepCount, err := rd.ReadIntReply() + if err != nil { + return nil, err + } + cmd.StepCount = int8(stepCount) + + for _, flag := range cmd.Flags { + if flag == "readonly" { + cmd.ReadOnly = true + break + } + } + + return &cmd, nil +} + +//------------------------------------------------------------------------------ + +type cmdsInfoCache struct { + fn func() (map[string]*CommandInfo, error) + + once internal.Once + cmds map[string]*CommandInfo +} + +func newCmdsInfoCache(fn func() (map[string]*CommandInfo, error)) *cmdsInfoCache { + return 
&cmdsInfoCache{ + fn: fn, + } +} + +func (c *cmdsInfoCache) Get() (map[string]*CommandInfo, error) { + err := c.once.Do(func() error { + cmds, err := c.fn() + if err != nil { + return err + } + c.cmds = cmds + return nil + }) + return c.cmds, err +} diff --git a/vendor/github.com/go-redis/redis/commands.go b/vendor/github.com/go-redis/redis/commands.go new file mode 100644 index 00000000..e9a8992f --- /dev/null +++ b/vendor/github.com/go-redis/redis/commands.go @@ -0,0 +1,2498 @@ +package redis + +import ( + "errors" + "io" + "time" + + "github.com/go-redis/redis/internal" +) + +func usePrecise(dur time.Duration) bool { + return dur < time.Second || dur%time.Second != 0 +} + +func formatMs(dur time.Duration) int64 { + if dur > 0 && dur < time.Millisecond { + internal.Logf( + "specified duration is %s, but minimal supported value is %s", + dur, time.Millisecond, + ) + } + return int64(dur / time.Millisecond) +} + +func formatSec(dur time.Duration) int64 { + if dur > 0 && dur < time.Second { + internal.Logf( + "specified duration is %s, but minimal supported value is %s", + dur, time.Second, + ) + } + return int64(dur / time.Second) +} + +func appendArgs(dst, src []interface{}) []interface{} { + if len(src) == 1 { + if ss, ok := src[0].([]string); ok { + for _, s := range ss { + dst = append(dst, s) + } + return dst + } + } + + for _, v := range src { + dst = append(dst, v) + } + return dst +} + +type Cmdable interface { + Pipeline() Pipeliner + Pipelined(fn func(Pipeliner) error) ([]Cmder, error) + + TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) + TxPipeline() Pipeliner + + Command() *CommandsInfoCmd + ClientGetName() *StringCmd + Echo(message interface{}) *StringCmd + Ping() *StatusCmd + Quit() *StatusCmd + Del(keys ...string) *IntCmd + Unlink(keys ...string) *IntCmd + Dump(key string) *StringCmd + Exists(keys ...string) *IntCmd + Expire(key string, expiration time.Duration) *BoolCmd + ExpireAt(key string, tm time.Time) *BoolCmd + Keys(pattern string) *StringSliceCmd + Migrate(host, port, key string, db int64, timeout time.Duration) *StatusCmd + Move(key string, db int64) *BoolCmd + ObjectRefCount(key string) *IntCmd + ObjectEncoding(key string) *StringCmd + ObjectIdleTime(key string) *DurationCmd + Persist(key string) *BoolCmd + PExpire(key string, expiration time.Duration) *BoolCmd + PExpireAt(key string, tm time.Time) *BoolCmd + PTTL(key string) *DurationCmd + RandomKey() *StringCmd + Rename(key, newkey string) *StatusCmd + RenameNX(key, newkey string) *BoolCmd + Restore(key string, ttl time.Duration, value string) *StatusCmd + RestoreReplace(key string, ttl time.Duration, value string) *StatusCmd + Sort(key string, sort *Sort) *StringSliceCmd + SortStore(key, store string, sort *Sort) *IntCmd + SortInterfaces(key string, sort *Sort) *SliceCmd + Touch(keys ...string) *IntCmd + TTL(key string) *DurationCmd + Type(key string) *StatusCmd + Scan(cursor uint64, match string, count int64) *ScanCmd + SScan(key string, cursor uint64, match string, count int64) *ScanCmd + HScan(key string, cursor uint64, match string, count int64) *ScanCmd + ZScan(key string, cursor uint64, match string, count int64) *ScanCmd + Append(key, value string) *IntCmd + BitCount(key string, bitCount *BitCount) *IntCmd + BitOpAnd(destKey string, keys ...string) *IntCmd + BitOpOr(destKey string, keys ...string) *IntCmd + BitOpXor(destKey string, keys ...string) *IntCmd + BitOpNot(destKey string, key string) *IntCmd + BitPos(key string, bit int64, pos ...int64) *IntCmd + Decr(key string) *IntCmd + DecrBy(key 
string, decrement int64) *IntCmd + Get(key string) *StringCmd + GetBit(key string, offset int64) *IntCmd + GetRange(key string, start, end int64) *StringCmd + GetSet(key string, value interface{}) *StringCmd + Incr(key string) *IntCmd + IncrBy(key string, value int64) *IntCmd + IncrByFloat(key string, value float64) *FloatCmd + MGet(keys ...string) *SliceCmd + MSet(pairs ...interface{}) *StatusCmd + MSetNX(pairs ...interface{}) *BoolCmd + Set(key string, value interface{}, expiration time.Duration) *StatusCmd + SetBit(key string, offset int64, value int) *IntCmd + SetNX(key string, value interface{}, expiration time.Duration) *BoolCmd + SetXX(key string, value interface{}, expiration time.Duration) *BoolCmd + SetRange(key string, offset int64, value string) *IntCmd + StrLen(key string) *IntCmd + HDel(key string, fields ...string) *IntCmd + HExists(key, field string) *BoolCmd + HGet(key, field string) *StringCmd + HGetAll(key string) *StringStringMapCmd + HIncrBy(key, field string, incr int64) *IntCmd + HIncrByFloat(key, field string, incr float64) *FloatCmd + HKeys(key string) *StringSliceCmd + HLen(key string) *IntCmd + HMGet(key string, fields ...string) *SliceCmd + HMSet(key string, fields map[string]interface{}) *StatusCmd + HSet(key, field string, value interface{}) *BoolCmd + HSetNX(key, field string, value interface{}) *BoolCmd + HVals(key string) *StringSliceCmd + BLPop(timeout time.Duration, keys ...string) *StringSliceCmd + BRPop(timeout time.Duration, keys ...string) *StringSliceCmd + BRPopLPush(source, destination string, timeout time.Duration) *StringCmd + LIndex(key string, index int64) *StringCmd + LInsert(key, op string, pivot, value interface{}) *IntCmd + LInsertBefore(key string, pivot, value interface{}) *IntCmd + LInsertAfter(key string, pivot, value interface{}) *IntCmd + LLen(key string) *IntCmd + LPop(key string) *StringCmd + LPush(key string, values ...interface{}) *IntCmd + LPushX(key string, value interface{}) *IntCmd + LRange(key string, start, stop int64) *StringSliceCmd + LRem(key string, count int64, value interface{}) *IntCmd + LSet(key string, index int64, value interface{}) *StatusCmd + LTrim(key string, start, stop int64) *StatusCmd + RPop(key string) *StringCmd + RPopLPush(source, destination string) *StringCmd + RPush(key string, values ...interface{}) *IntCmd + RPushX(key string, value interface{}) *IntCmd + SAdd(key string, members ...interface{}) *IntCmd + SCard(key string) *IntCmd + SDiff(keys ...string) *StringSliceCmd + SDiffStore(destination string, keys ...string) *IntCmd + SInter(keys ...string) *StringSliceCmd + SInterStore(destination string, keys ...string) *IntCmd + SIsMember(key string, member interface{}) *BoolCmd + SMembers(key string) *StringSliceCmd + SMembersMap(key string) *StringStructMapCmd + SMove(source, destination string, member interface{}) *BoolCmd + SPop(key string) *StringCmd + SPopN(key string, count int64) *StringSliceCmd + SRandMember(key string) *StringCmd + SRandMemberN(key string, count int64) *StringSliceCmd + SRem(key string, members ...interface{}) *IntCmd + SUnion(keys ...string) *StringSliceCmd + SUnionStore(destination string, keys ...string) *IntCmd + XAdd(a *XAddArgs) *StringCmd + XLen(stream string) *IntCmd + XRange(stream, start, stop string) *XMessageSliceCmd + XRangeN(stream, start, stop string, count int64) *XMessageSliceCmd + XRevRange(stream string, start, stop string) *XMessageSliceCmd + XRevRangeN(stream string, start, stop string, count int64) *XMessageSliceCmd + XRead(a *XReadArgs) *XStreamSliceCmd + 
XReadStreams(streams ...string) *XStreamSliceCmd + XGroupCreate(stream, group, start string) *StatusCmd + XGroupSetID(stream, group, start string) *StatusCmd + XGroupDestroy(stream, group string) *IntCmd + XGroupDelConsumer(stream, group, consumer string) *IntCmd + XReadGroup(a *XReadGroupArgs) *XStreamSliceCmd + XAck(stream, group string, ids ...string) *IntCmd + XPending(stream, group string) *XPendingCmd + XPendingExt(a *XPendingExtArgs) *XPendingExtCmd + XClaim(a *XClaimArgs) *XMessageSliceCmd + XClaimJustID(a *XClaimArgs) *StringSliceCmd + XTrim(key string, maxLen int64) *IntCmd + XTrimApprox(key string, maxLen int64) *IntCmd + ZAdd(key string, members ...Z) *IntCmd + ZAddNX(key string, members ...Z) *IntCmd + ZAddXX(key string, members ...Z) *IntCmd + ZAddCh(key string, members ...Z) *IntCmd + ZAddNXCh(key string, members ...Z) *IntCmd + ZAddXXCh(key string, members ...Z) *IntCmd + ZIncr(key string, member Z) *FloatCmd + ZIncrNX(key string, member Z) *FloatCmd + ZIncrXX(key string, member Z) *FloatCmd + ZCard(key string) *IntCmd + ZCount(key, min, max string) *IntCmd + ZLexCount(key, min, max string) *IntCmd + ZIncrBy(key string, increment float64, member string) *FloatCmd + ZInterStore(destination string, store ZStore, keys ...string) *IntCmd + ZPopMax(key string, count ...int64) *ZSliceCmd + ZPopMin(key string, count ...int64) *ZSliceCmd + ZRange(key string, start, stop int64) *StringSliceCmd + ZRangeWithScores(key string, start, stop int64) *ZSliceCmd + ZRangeByScore(key string, opt ZRangeBy) *StringSliceCmd + ZRangeByLex(key string, opt ZRangeBy) *StringSliceCmd + ZRangeByScoreWithScores(key string, opt ZRangeBy) *ZSliceCmd + ZRank(key, member string) *IntCmd + ZRem(key string, members ...interface{}) *IntCmd + ZRemRangeByRank(key string, start, stop int64) *IntCmd + ZRemRangeByScore(key, min, max string) *IntCmd + ZRemRangeByLex(key, min, max string) *IntCmd + ZRevRange(key string, start, stop int64) *StringSliceCmd + ZRevRangeWithScores(key string, start, stop int64) *ZSliceCmd + ZRevRangeByScore(key string, opt ZRangeBy) *StringSliceCmd + ZRevRangeByLex(key string, opt ZRangeBy) *StringSliceCmd + ZRevRangeByScoreWithScores(key string, opt ZRangeBy) *ZSliceCmd + ZRevRank(key, member string) *IntCmd + ZScore(key, member string) *FloatCmd + ZUnionStore(dest string, store ZStore, keys ...string) *IntCmd + PFAdd(key string, els ...interface{}) *IntCmd + PFCount(keys ...string) *IntCmd + PFMerge(dest string, keys ...string) *StatusCmd + BgRewriteAOF() *StatusCmd + BgSave() *StatusCmd + ClientKill(ipPort string) *StatusCmd + ClientKillByFilter(keys ...string) *IntCmd + ClientList() *StringCmd + ClientPause(dur time.Duration) *BoolCmd + ConfigGet(parameter string) *SliceCmd + ConfigResetStat() *StatusCmd + ConfigSet(parameter, value string) *StatusCmd + ConfigRewrite() *StatusCmd + DBSize() *IntCmd + FlushAll() *StatusCmd + FlushAllAsync() *StatusCmd + FlushDB() *StatusCmd + FlushDBAsync() *StatusCmd + Info(section ...string) *StringCmd + LastSave() *IntCmd + Save() *StatusCmd + Shutdown() *StatusCmd + ShutdownSave() *StatusCmd + ShutdownNoSave() *StatusCmd + SlaveOf(host, port string) *StatusCmd + Time() *TimeCmd + Eval(script string, keys []string, args ...interface{}) *Cmd + EvalSha(sha1 string, keys []string, args ...interface{}) *Cmd + ScriptExists(hashes ...string) *BoolSliceCmd + ScriptFlush() *StatusCmd + ScriptKill() *StatusCmd + ScriptLoad(script string) *StringCmd + DebugObject(key string) *StringCmd + Publish(channel string, message interface{}) *IntCmd + 
PubSubChannels(pattern string) *StringSliceCmd + PubSubNumSub(channels ...string) *StringIntMapCmd + PubSubNumPat() *IntCmd + ClusterSlots() *ClusterSlotsCmd + ClusterNodes() *StringCmd + ClusterMeet(host, port string) *StatusCmd + ClusterForget(nodeID string) *StatusCmd + ClusterReplicate(nodeID string) *StatusCmd + ClusterResetSoft() *StatusCmd + ClusterResetHard() *StatusCmd + ClusterInfo() *StringCmd + ClusterKeySlot(key string) *IntCmd + ClusterCountFailureReports(nodeID string) *IntCmd + ClusterCountKeysInSlot(slot int) *IntCmd + ClusterDelSlots(slots ...int) *StatusCmd + ClusterDelSlotsRange(min, max int) *StatusCmd + ClusterSaveConfig() *StatusCmd + ClusterSlaves(nodeID string) *StringSliceCmd + ClusterFailover() *StatusCmd + ClusterAddSlots(slots ...int) *StatusCmd + ClusterAddSlotsRange(min, max int) *StatusCmd + GeoAdd(key string, geoLocation ...*GeoLocation) *IntCmd + GeoPos(key string, members ...string) *GeoPosCmd + GeoRadius(key string, longitude, latitude float64, query *GeoRadiusQuery) *GeoLocationCmd + GeoRadiusRO(key string, longitude, latitude float64, query *GeoRadiusQuery) *GeoLocationCmd + GeoRadiusByMember(key, member string, query *GeoRadiusQuery) *GeoLocationCmd + GeoRadiusByMemberRO(key, member string, query *GeoRadiusQuery) *GeoLocationCmd + GeoDist(key string, member1, member2, unit string) *FloatCmd + GeoHash(key string, members ...string) *StringSliceCmd + ReadOnly() *StatusCmd + ReadWrite() *StatusCmd + MemoryUsage(key string, samples ...int) *IntCmd +} + +type StatefulCmdable interface { + Cmdable + Auth(password string) *StatusCmd + Select(index int) *StatusCmd + SwapDB(index1, index2 int) *StatusCmd + ClientSetName(name string) *BoolCmd +} + +var _ Cmdable = (*Client)(nil) +var _ Cmdable = (*Tx)(nil) +var _ Cmdable = (*Ring)(nil) +var _ Cmdable = (*ClusterClient)(nil) + +type cmdable struct { + process func(cmd Cmder) error +} + +func (c *cmdable) setProcessor(fn func(Cmder) error) { + c.process = fn +} + +type statefulCmdable struct { + cmdable + process func(cmd Cmder) error +} + +func (c *statefulCmdable) setProcessor(fn func(Cmder) error) { + c.process = fn + c.cmdable.setProcessor(fn) +} + +//------------------------------------------------------------------------------ + +func (c *statefulCmdable) Auth(password string) *StatusCmd { + cmd := NewStatusCmd("auth", password) + c.process(cmd) + return cmd +} + +func (c *cmdable) Echo(message interface{}) *StringCmd { + cmd := NewStringCmd("echo", message) + c.process(cmd) + return cmd +} + +func (c *cmdable) Ping() *StatusCmd { + cmd := NewStatusCmd("ping") + c.process(cmd) + return cmd +} + +func (c *cmdable) Wait(numSlaves int, timeout time.Duration) *IntCmd { + cmd := NewIntCmd("wait", numSlaves, int(timeout/time.Millisecond)) + c.process(cmd) + return cmd +} + +func (c *cmdable) Quit() *StatusCmd { + panic("not implemented") +} + +func (c *statefulCmdable) Select(index int) *StatusCmd { + cmd := NewStatusCmd("select", index) + c.process(cmd) + return cmd +} + +func (c *statefulCmdable) SwapDB(index1, index2 int) *StatusCmd { + cmd := NewStatusCmd("swapdb", index1, index2) + c.process(cmd) + return cmd +} + +//------------------------------------------------------------------------------ + +func (c *cmdable) Command() *CommandsInfoCmd { + cmd := NewCommandsInfoCmd("command") + c.process(cmd) + return cmd +} + +func (c *cmdable) Del(keys ...string) *IntCmd { + args := make([]interface{}, 1+len(keys)) + args[0] = "del" + for i, key := range keys { + args[1+i] = key + } + cmd := NewIntCmd(args...) 
+
+func (c *cmdable) Del(keys ...string) *IntCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "del"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Unlink(keys ...string) *IntCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "unlink"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Dump(key string) *StringCmd {
+	cmd := NewStringCmd("dump", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Exists(keys ...string) *IntCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "exists"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Expire(key string, expiration time.Duration) *BoolCmd {
+	cmd := NewBoolCmd("expire", key, formatSec(expiration))
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ExpireAt(key string, tm time.Time) *BoolCmd {
+	cmd := NewBoolCmd("expireat", key, tm.Unix())
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Keys(pattern string) *StringSliceCmd {
+	cmd := NewStringSliceCmd("keys", pattern)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Migrate(host, port, key string, db int64, timeout time.Duration) *StatusCmd {
+	cmd := NewStatusCmd(
+		"migrate",
+		host,
+		port,
+		key,
+		db,
+		formatMs(timeout),
+	)
+	cmd.setReadTimeout(timeout)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Move(key string, db int64) *BoolCmd {
+	cmd := NewBoolCmd("move", key, db)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ObjectRefCount(key string) *IntCmd {
+	cmd := NewIntCmd("object", "refcount", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ObjectEncoding(key string) *StringCmd {
+	cmd := NewStringCmd("object", "encoding", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ObjectIdleTime(key string) *DurationCmd {
+	cmd := NewDurationCmd(time.Second, "object", "idletime", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Persist(key string) *BoolCmd {
+	cmd := NewBoolCmd("persist", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) PExpire(key string, expiration time.Duration) *BoolCmd {
+	cmd := NewBoolCmd("pexpire", key, formatMs(expiration))
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) PExpireAt(key string, tm time.Time) *BoolCmd {
+	cmd := NewBoolCmd(
+		"pexpireat",
+		key,
+		tm.UnixNano()/int64(time.Millisecond),
+	)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) PTTL(key string) *DurationCmd {
+	cmd := NewDurationCmd(time.Millisecond, "pttl", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RandomKey() *StringCmd {
+	cmd := NewStringCmd("randomkey")
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Rename(key, newkey string) *StatusCmd {
+	cmd := NewStatusCmd("rename", key, newkey)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RenameNX(key, newkey string) *BoolCmd {
+	cmd := NewBoolCmd("renamenx", key, newkey)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Restore(key string, ttl time.Duration, value string) *StatusCmd {
+	cmd := NewStatusCmd(
+		"restore",
+		key,
+		formatMs(ttl),
+		value,
+	)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RestoreReplace(key string, ttl time.Duration, value string) *StatusCmd {
+	cmd := NewStatusCmd(
+		"restore",
+		key,
+		formatMs(ttl),
+		value,
+		"replace",
+	)
+	c.process(cmd)
+	return cmd
+}
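The expiration helpers above differ only in units: Expire and ExpireAt go through formatSec (seconds), while the P-prefixed variants use formatMs (milliseconds). A hedged usage sketch, reusing the `client` from the first example (key name illustrative):

    ok, err := client.Expire("session:42", 10*time.Minute).Result()
    if err != nil || !ok {
        // the call failed, or the key does not exist
    }
    ttl, _ := client.TTL("session:42").Result() // remaining TTL as a time.Duration
    fmt.Println(ttl)
    _ = client.Persist("session:42").Err() // drop the TTL again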
+
+type Sort struct {
+	By            string
+	Offset, Count int64
+	Get           []string
+	Order         string
+	Alpha         bool
+}
+
+func (sort *Sort) args(key string) []interface{} {
+	args := []interface{}{"sort", key}
+	if sort.By != "" {
+		args = append(args, "by", sort.By)
+	}
+	if sort.Offset != 0 || sort.Count != 0 {
+		args = append(args, "limit", sort.Offset, sort.Count)
+	}
+	for _, get := range sort.Get {
+		args = append(args, "get", get)
+	}
+	if sort.Order != "" {
+		args = append(args, sort.Order)
+	}
+	if sort.Alpha {
+		args = append(args, "alpha")
+	}
+	return args
+}
+
+func (c *cmdable) Sort(key string, sort *Sort) *StringSliceCmd {
+	cmd := NewStringSliceCmd(sort.args(key)...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SortStore(key, store string, sort *Sort) *IntCmd {
+	args := sort.args(key)
+	if store != "" {
+		args = append(args, "store", store)
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SortInterfaces(key string, sort *Sort) *SliceCmd {
+	cmd := NewSliceCmd(sort.args(key)...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Touch(keys ...string) *IntCmd {
+	args := make([]interface{}, len(keys)+1)
+	args[0] = "touch"
+	for i, key := range keys {
+		args[i+1] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) TTL(key string) *DurationCmd {
+	cmd := NewDurationCmd(time.Second, "ttl", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Type(key string) *StatusCmd {
+	cmd := NewStatusCmd("type", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Scan(cursor uint64, match string, count int64) *ScanCmd {
+	args := []interface{}{"scan", cursor}
+	if match != "" {
+		args = append(args, "match", match)
+	}
+	if count > 0 {
+		args = append(args, "count", count)
+	}
+	cmd := NewScanCmd(c.process, args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SScan(key string, cursor uint64, match string, count int64) *ScanCmd {
+	args := []interface{}{"sscan", key, cursor}
+	if match != "" {
+		args = append(args, "match", match)
+	}
+	if count > 0 {
+		args = append(args, "count", count)
+	}
+	cmd := NewScanCmd(c.process, args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HScan(key string, cursor uint64, match string, count int64) *ScanCmd {
+	args := []interface{}{"hscan", key, cursor}
+	if match != "" {
+		args = append(args, "match", match)
+	}
+	if count > 0 {
+		args = append(args, "count", count)
+	}
+	cmd := NewScanCmd(c.process, args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZScan(key string, cursor uint64, match string, count int64) *ScanCmd {
+	args := []interface{}{"zscan", key, cursor}
+	if match != "" {
+		args = append(args, "match", match)
+	}
+	if count > 0 {
+		args = append(args, "count", count)
+	}
+	cmd := NewScanCmd(c.process, args...)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) Append(key, value string) *IntCmd {
+	cmd := NewIntCmd("append", key, value)
+	c.process(cmd)
+	return cmd
+}
+
+type BitCount struct {
+	Start, End int64
+}
+
+func (c *cmdable) BitCount(key string, bitCount *BitCount) *IntCmd {
+	args := []interface{}{"bitcount", key}
+	if bitCount != nil {
+		args = append(
+			args,
+			bitCount.Start,
+			bitCount.End,
+		)
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) bitOp(op, destKey string, keys ...string) *IntCmd {
+	args := make([]interface{}, 3+len(keys))
+	args[0] = "bitop"
+	args[1] = op
+	args[2] = destKey
+	for i, key := range keys {
+		args[3+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
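Note that the four *Scan constructors pass `c.process` into NewScanCmd; that is what lets the returned ScanCmd issue follow-up calls on behalf of the caller. A manual cursor loop, as a sketch against the `client` from the first example (pattern and count are illustrative):

    var cursor uint64
    for {
        keys, next, err := client.Scan(cursor, "user:*", 100).Result()
        if err != nil {
            log.Fatal(err)
        }
        for _, k := range keys {
            fmt.Println(k)
        }
        if next == 0 {
            break // a zero cursor marks the end of the iteration
        }
        cursor = next
    }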
+
+func (c *cmdable) BitOpAnd(destKey string, keys ...string) *IntCmd {
+	return c.bitOp("and", destKey, keys...)
+}
+
+func (c *cmdable) BitOpOr(destKey string, keys ...string) *IntCmd {
+	return c.bitOp("or", destKey, keys...)
+}
+
+func (c *cmdable) BitOpXor(destKey string, keys ...string) *IntCmd {
+	return c.bitOp("xor", destKey, keys...)
+}
+
+func (c *cmdable) BitOpNot(destKey string, key string) *IntCmd {
+	return c.bitOp("not", destKey, key)
+}
+
+func (c *cmdable) BitPos(key string, bit int64, pos ...int64) *IntCmd {
+	args := make([]interface{}, 3+len(pos))
+	args[0] = "bitpos"
+	args[1] = key
+	args[2] = bit
+	switch len(pos) {
+	case 0:
+	case 1:
+		args[3] = pos[0]
+	case 2:
+		args[3] = pos[0]
+		args[4] = pos[1]
+	default:
+		panic("too many arguments")
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Decr(key string) *IntCmd {
+	cmd := NewIntCmd("decr", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) DecrBy(key string, decrement int64) *IntCmd {
+	cmd := NewIntCmd("decrby", key, decrement)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `GET key` command. It returns redis.Nil error when key does not exist.
+func (c *cmdable) Get(key string) *StringCmd {
+	cmd := NewStringCmd("get", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) GetBit(key string, offset int64) *IntCmd {
+	cmd := NewIntCmd("getbit", key, offset)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) GetRange(key string, start, end int64) *StringCmd {
+	cmd := NewStringCmd("getrange", key, start, end)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) GetSet(key string, value interface{}) *StringCmd {
+	cmd := NewStringCmd("getset", key, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) Incr(key string) *IntCmd {
+	cmd := NewIntCmd("incr", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) IncrBy(key string, value int64) *IntCmd {
+	cmd := NewIntCmd("incrby", key, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) IncrByFloat(key string, value float64) *FloatCmd {
+	cmd := NewFloatCmd("incrbyfloat", key, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) MGet(keys ...string) *SliceCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "mget"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) MSet(pairs ...interface{}) *StatusCmd {
+	args := make([]interface{}, 1, 1+len(pairs))
+	args[0] = "mset"
+	args = appendArgs(args, pairs)
+	cmd := NewStatusCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) MSetNX(pairs ...interface{}) *BoolCmd {
+	args := make([]interface{}, 1, 1+len(pairs))
+	args[0] = "msetnx"
+	args = appendArgs(args, pairs)
+	cmd := NewBoolCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SET key value [expiration]` command.
+//
+// Use expiration for `SETEX`-like behavior.
+// Zero expiration means the key has no expiration time.
+func (c *cmdable) Set(key string, value interface{}, expiration time.Duration) *StatusCmd {
+	args := make([]interface{}, 3, 4)
+	args[0] = "set"
+	args[1] = key
+	args[2] = value
+	if expiration > 0 {
+		if usePrecise(expiration) {
+			args = append(args, "px", formatMs(expiration))
+		} else {
+			args = append(args, "ex", formatSec(expiration))
+		}
+	}
+	cmd := NewStatusCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SetBit(key string, offset int64, value int) *IntCmd {
+	cmd := NewIntCmd(
+		"setbit",
+		key,
+		offset,
+		value,
+	)
+	c.process(cmd)
+	return cmd
+}
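Set only appends an `ex`/`px` token when the expiration is positive, so a zero duration produces a plain SET with no TTL, and Get surfaces a missing key as the sentinel error redis.Nil rather than an empty string. For instance, with the `client` from the first sketch:

    if err := client.Set("greeting", "hello", 30*time.Second).Err(); err != nil {
        log.Fatal(err)
    }
    val, err := client.Get("greeting").Result()
    if err == redis.Nil {
        fmt.Println("key expired or was never set")
    } else if err == nil {
        fmt.Println(val)
    }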
+
+// Redis `SET key value [expiration] NX` command.
+//
+// Zero expiration means the key has no expiration time.
+func (c *cmdable) SetNX(key string, value interface{}, expiration time.Duration) *BoolCmd {
+	var cmd *BoolCmd
+	if expiration == 0 {
+		// Use old `SETNX` to support old Redis versions.
+		cmd = NewBoolCmd("setnx", key, value)
+	} else {
+		if usePrecise(expiration) {
+			cmd = NewBoolCmd("set", key, value, "px", formatMs(expiration), "nx")
+		} else {
+			cmd = NewBoolCmd("set", key, value, "ex", formatSec(expiration), "nx")
+		}
+	}
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SET key value [expiration] XX` command.
+//
+// Zero expiration means the key has no expiration time.
+func (c *cmdable) SetXX(key string, value interface{}, expiration time.Duration) *BoolCmd {
+	var cmd *BoolCmd
+	if expiration == 0 {
+		cmd = NewBoolCmd("set", key, value, "xx")
+	} else {
+		if usePrecise(expiration) {
+			cmd = NewBoolCmd("set", key, value, "px", formatMs(expiration), "xx")
+		} else {
+			cmd = NewBoolCmd("set", key, value, "ex", formatSec(expiration), "xx")
+		}
+	}
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SetRange(key string, offset int64, value string) *IntCmd {
+	cmd := NewIntCmd("setrange", key, offset, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) StrLen(key string) *IntCmd {
+	cmd := NewIntCmd("strlen", key)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) HDel(key string, fields ...string) *IntCmd {
+	args := make([]interface{}, 2+len(fields))
+	args[0] = "hdel"
+	args[1] = key
+	for i, field := range fields {
+		args[2+i] = field
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HExists(key, field string) *BoolCmd {
+	cmd := NewBoolCmd("hexists", key, field)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HGet(key, field string) *StringCmd {
+	cmd := NewStringCmd("hget", key, field)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HGetAll(key string) *StringStringMapCmd {
+	cmd := NewStringStringMapCmd("hgetall", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HIncrBy(key, field string, incr int64) *IntCmd {
+	cmd := NewIntCmd("hincrby", key, field, incr)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HIncrByFloat(key, field string, incr float64) *FloatCmd {
+	cmd := NewFloatCmd("hincrbyfloat", key, field, incr)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HKeys(key string) *StringSliceCmd {
+	cmd := NewStringSliceCmd("hkeys", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HLen(key string) *IntCmd {
+	cmd := NewIntCmd("hlen", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HMGet(key string, fields ...string) *SliceCmd {
+	args := make([]interface{}, 2+len(fields))
+	args[0] = "hmget"
+	args[1] = key
+	for i, field := range fields {
+		args[2+i] = field
+	}
+	cmd := NewSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HMSet(key string, fields map[string]interface{}) *StatusCmd {
+	args := make([]interface{}, 2+len(fields)*2)
+	args[0] = "hmset"
+	args[1] = key
+	i := 2
+	for k, v := range fields {
+		args[i] = k
+		args[i+1] = v
+		i += 2
+	}
+	cmd := NewStatusCmd(args...)
+	c.process(cmd)
+	return cmd
+}
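HMSet flattens its map into an alternating field/value argument list, and HGetAll does the inverse on the reply, yielding a map[string]string. A sketch, reusing the earlier `client` (key and fields illustrative):

    err := client.HMSet("user:1", map[string]interface{}{
        "name": "gopher",
        "age":  7,
    }).Err()
    if err != nil {
        log.Fatal(err)
    }
    fields, _ := client.HGetAll("user:1").Result() // map[string]string{"name":"gopher", "age":"7"}
    fmt.Println(fields["name"])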
+
+func (c *cmdable) HSet(key, field string, value interface{}) *BoolCmd {
+	cmd := NewBoolCmd("hset", key, field, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HSetNX(key, field string, value interface{}) *BoolCmd {
+	cmd := NewBoolCmd("hsetnx", key, field, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) HVals(key string) *StringSliceCmd {
+	cmd := NewStringSliceCmd("hvals", key)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) BLPop(timeout time.Duration, keys ...string) *StringSliceCmd {
+	args := make([]interface{}, 1+len(keys)+1)
+	args[0] = "blpop"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	args[len(args)-1] = formatSec(timeout)
+	cmd := NewStringSliceCmd(args...)
+	cmd.setReadTimeout(timeout)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) BRPop(timeout time.Duration, keys ...string) *StringSliceCmd {
+	args := make([]interface{}, 1+len(keys)+1)
+	args[0] = "brpop"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	args[len(keys)+1] = formatSec(timeout)
+	cmd := NewStringSliceCmd(args...)
+	cmd.setReadTimeout(timeout)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) BRPopLPush(source, destination string, timeout time.Duration) *StringCmd {
+	cmd := NewStringCmd(
+		"brpoplpush",
+		source,
+		destination,
+		formatSec(timeout),
+	)
+	cmd.setReadTimeout(timeout)
+	c.process(cmd)
+	return cmd
+}
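The blocking variants call cmd.setReadTimeout(timeout) so the socket read deadline is stretched to cover the server-side block; without that, a long BLPOP would be cut short by the client's normal read timeout. An illustrative consumer, again reusing `client` (the list key is hypothetical):

    // Blocks up to 5 seconds; the reply is [key, value] for the first
    // non-empty list among the given keys.
    kv, err := client.BLPop(5*time.Second, "jobs").Result()
    if err == redis.Nil {
        // timed out with nothing to pop
    } else if err == nil {
        fmt.Printf("from %s got %s\n", kv[0], kv[1])
    }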
+
+func (c *cmdable) LIndex(key string, index int64) *StringCmd {
+	cmd := NewStringCmd("lindex", key, index)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LInsert(key, op string, pivot, value interface{}) *IntCmd {
+	cmd := NewIntCmd("linsert", key, op, pivot, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LInsertBefore(key string, pivot, value interface{}) *IntCmd {
+	cmd := NewIntCmd("linsert", key, "before", pivot, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LInsertAfter(key string, pivot, value interface{}) *IntCmd {
+	cmd := NewIntCmd("linsert", key, "after", pivot, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LLen(key string) *IntCmd {
+	cmd := NewIntCmd("llen", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LPop(key string) *StringCmd {
+	cmd := NewStringCmd("lpop", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LPush(key string, values ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(values))
+	args[0] = "lpush"
+	args[1] = key
+	args = appendArgs(args, values)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LPushX(key string, value interface{}) *IntCmd {
+	cmd := NewIntCmd("lpushx", key, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LRange(key string, start, stop int64) *StringSliceCmd {
+	cmd := NewStringSliceCmd(
+		"lrange",
+		key,
+		start,
+		stop,
+	)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LRem(key string, count int64, value interface{}) *IntCmd {
+	cmd := NewIntCmd("lrem", key, count, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LSet(key string, index int64, value interface{}) *StatusCmd {
+	cmd := NewStatusCmd("lset", key, index, value)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) LTrim(key string, start, stop int64) *StatusCmd {
+	cmd := NewStatusCmd(
+		"ltrim",
+		key,
+		start,
+		stop,
+	)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RPop(key string) *StringCmd {
+	cmd := NewStringCmd("rpop", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RPopLPush(source, destination string) *StringCmd {
+	cmd := NewStringCmd("rpoplpush", source, destination)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RPush(key string, values ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(values))
+	args[0] = "rpush"
+	args[1] = key
+	args = appendArgs(args, values)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) RPushX(key string, value interface{}) *IntCmd {
+	cmd := NewIntCmd("rpushx", key, value)
+	c.process(cmd)
+	return cmd
+}
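LPush and RPush accept variadic values through appendArgs, so one round trip can push several elements. A small queue sketch with the earlier `client` (key name illustrative):

    _ = client.RPush("queue", "a", "b", "c").Err() // append three elements
    items, _ := client.LRange("queue", 0, -1).Result()
    fmt.Println(items) // [a b c]
    head, _ := client.LPop("queue").Result()
    fmt.Println(head) // "a"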
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) SAdd(key string, members ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(members))
+	args[0] = "sadd"
+	args[1] = key
+	args = appendArgs(args, members)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SCard(key string) *IntCmd {
+	cmd := NewIntCmd("scard", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SDiff(keys ...string) *StringSliceCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "sdiff"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SDiffStore(destination string, keys ...string) *IntCmd {
+	args := make([]interface{}, 2+len(keys))
+	args[0] = "sdiffstore"
+	args[1] = destination
+	for i, key := range keys {
+		args[2+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SInter(keys ...string) *StringSliceCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "sinter"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SInterStore(destination string, keys ...string) *IntCmd {
+	args := make([]interface{}, 2+len(keys))
+	args[0] = "sinterstore"
+	args[1] = destination
+	for i, key := range keys {
+		args[2+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SIsMember(key string, member interface{}) *BoolCmd {
+	cmd := NewBoolCmd("sismember", key, member)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SMEMBERS key` command output as a slice.
+func (c *cmdable) SMembers(key string) *StringSliceCmd {
+	cmd := NewStringSliceCmd("smembers", key)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SMEMBERS key` command output as a map.
+func (c *cmdable) SMembersMap(key string) *StringStructMapCmd {
+	cmd := NewStringStructMapCmd("smembers", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SMove(source, destination string, member interface{}) *BoolCmd {
+	cmd := NewBoolCmd("smove", source, destination, member)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SPOP key` command.
+func (c *cmdable) SPop(key string) *StringCmd {
+	cmd := NewStringCmd("spop", key)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SPOP key count` command.
+func (c *cmdable) SPopN(key string, count int64) *StringSliceCmd {
+	cmd := NewStringSliceCmd("spop", key, count)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SRANDMEMBER key` command.
+func (c *cmdable) SRandMember(key string) *StringCmd {
+	cmd := NewStringCmd("srandmember", key)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `SRANDMEMBER key count` command.
+func (c *cmdable) SRandMemberN(key string, count int64) *StringSliceCmd {
+	cmd := NewStringSliceCmd("srandmember", key, count)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SRem(key string, members ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(members))
+	args[0] = "srem"
+	args[1] = key
+	args = appendArgs(args, members)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SUnion(keys ...string) *StringSliceCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "sunion"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) SUnionStore(destination string, keys ...string) *IntCmd {
+	args := make([]interface{}, 2+len(keys))
+	args[0] = "sunionstore"
+	args[1] = destination
+	for i, key := range keys {
+		args[2+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
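SMembers and SMembersMap issue the same SMEMBERS command and differ only in the reply shape: a slice, or a map[string]struct{} that is convenient for membership tests. For example, with the earlier `client`:

    _ = client.SAdd("colors", "red", "green", "blue").Err()
    asSlice, _ := client.SMembers("colors").Result()  // []string, order unspecified
    asSet, _ := client.SMembersMap("colors").Result() // map[string]struct{}
    _, isMember := asSet["red"]
    fmt.Println(len(asSlice), isMember) // 3 true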
+
+type XAddArgs struct {
+	Stream       string
+	MaxLen       int64 // MAXLEN N
+	MaxLenApprox int64 // MAXLEN ~ N
+	ID           string
+	Values       map[string]interface{}
+}
+
+func (c *cmdable) XAdd(a *XAddArgs) *StringCmd {
+	args := make([]interface{}, 0, 6+len(a.Values)*2)
+	args = append(args, "xadd")
+	args = append(args, a.Stream)
+	if a.MaxLen > 0 {
+		args = append(args, "maxlen", a.MaxLen)
+	} else if a.MaxLenApprox > 0 {
+		args = append(args, "maxlen", "~", a.MaxLenApprox)
+	}
+	if a.ID != "" {
+		args = append(args, a.ID)
+	} else {
+		args = append(args, "*")
+	}
+	for k, v := range a.Values {
+		args = append(args, k)
+		args = append(args, v)
+	}
+
+	cmd := NewStringCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XLen(stream string) *IntCmd {
+	cmd := NewIntCmd("xlen", stream)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XRange(stream, start, stop string) *XMessageSliceCmd {
+	cmd := NewXMessageSliceCmd("xrange", stream, start, stop)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XRangeN(stream, start, stop string, count int64) *XMessageSliceCmd {
+	cmd := NewXMessageSliceCmd("xrange", stream, start, stop, "count", count)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XRevRange(stream, start, stop string) *XMessageSliceCmd {
+	cmd := NewXMessageSliceCmd("xrevrange", stream, start, stop)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XRevRangeN(stream, start, stop string, count int64) *XMessageSliceCmd {
+	cmd := NewXMessageSliceCmd("xrevrange", stream, start, stop, "count", count)
+	c.process(cmd)
+	return cmd
+}
+
+type XReadArgs struct {
+	Streams []string
+	Count   int64
+	Block   time.Duration
+}
+
+func (c *cmdable) XRead(a *XReadArgs) *XStreamSliceCmd {
+	args := make([]interface{}, 0, 5+len(a.Streams))
+	args = append(args, "xread")
+	if a.Count > 0 {
+		args = append(args, "count")
+		args = append(args, a.Count)
+	}
+	if a.Block >= 0 {
+		args = append(args, "block")
+		args = append(args, int64(a.Block/time.Millisecond))
+	}
+	args = append(args, "streams")
+	for _, s := range a.Streams {
+		args = append(args, s)
+	}
+
+	cmd := NewXStreamSliceCmd(args...)
+	if a.Block >= 0 {
+		cmd.setReadTimeout(a.Block)
+	}
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XReadStreams(streams ...string) *XStreamSliceCmd {
+	return c.XRead(&XReadArgs{
+		Streams: streams,
+		Block:   -1,
+	})
+}
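XReadStreams is simply XRead with Block: -1, i.e. a non-blocking read, and the streams slice interleaves stream names and start IDs. A produce/consume sketch against the earlier `client` (stream name illustrative):

    id, _ := client.XAdd(&redis.XAddArgs{
        Stream: "events",
        Values: map[string]interface{}{"kind": "signup"},
    }).Result() // the server assigns the ID because XAddArgs.ID is empty
    fmt.Println(id)
    // "0" means: read everything from the beginning of the stream.
    res, _ := client.XReadStreams("events", "0").Result()
    fmt.Println(res[0].Messages)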
+
+func (c *cmdable) XGroupCreate(stream, group, start string) *StatusCmd {
+	cmd := NewStatusCmd("xgroup", "create", stream, group, start)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XGroupSetID(stream, group, start string) *StatusCmd {
+	cmd := NewStatusCmd("xgroup", "setid", stream, group, start)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XGroupDestroy(stream, group string) *IntCmd {
+	cmd := NewIntCmd("xgroup", "destroy", stream, group)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XGroupDelConsumer(stream, group, consumer string) *IntCmd {
+	cmd := NewIntCmd("xgroup", "delconsumer", stream, group, consumer)
+	c.process(cmd)
+	return cmd
+}
+
+type XReadGroupArgs struct {
+	Group    string
+	Consumer string
+	Streams  []string
+	Count    int64
+	Block    time.Duration
+}
+
+func (c *cmdable) XReadGroup(a *XReadGroupArgs) *XStreamSliceCmd {
+	args := make([]interface{}, 0, 8+len(a.Streams))
+	args = append(args, "xreadgroup", "group", a.Group, a.Consumer)
+	if a.Count > 0 {
+		args = append(args, "count", a.Count)
+	}
+	if a.Block >= 0 {
+		args = append(args, "block", int64(a.Block/time.Millisecond))
+	}
+	args = append(args, "streams")
+	for _, s := range a.Streams {
+		args = append(args, s)
+	}
+
+	cmd := NewXStreamSliceCmd(args...)
+	if a.Block >= 0 {
+		cmd.setReadTimeout(a.Block)
+	}
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XAck(stream, group string, ids ...string) *IntCmd {
+	args := []interface{}{"xack", stream, group}
+	for _, id := range ids {
+		args = append(args, id)
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XPending(stream, group string) *XPendingCmd {
+	cmd := NewXPendingCmd("xpending", stream, group)
+	c.process(cmd)
+	return cmd
+}
+
+type XPendingExtArgs struct {
+	Stream   string
+	Group    string
+	Start    string
+	End      string
+	Count    int64
+	Consumer string
+}
+
+func (c *cmdable) XPendingExt(a *XPendingExtArgs) *XPendingExtCmd {
+	args := make([]interface{}, 0, 7)
+	args = append(args, "xpending", a.Stream, a.Group, a.Start, a.End, a.Count)
+	if a.Consumer != "" {
+		args = append(args, a.Consumer)
+	}
+	cmd := NewXPendingExtCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+type XClaimArgs struct {
+	Stream   string
+	Group    string
+	Consumer string
+	MinIdle  time.Duration
+	Messages []string
+}
+
+func (c *cmdable) XClaim(a *XClaimArgs) *XMessageSliceCmd {
+	args := xClaimArgs(a)
+	cmd := NewXMessageSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XClaimJustID(a *XClaimArgs) *StringSliceCmd {
+	args := xClaimArgs(a)
+	args = append(args, "justid")
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func xClaimArgs(a *XClaimArgs) []interface{} {
+	args := make([]interface{}, 0, 4+len(a.Messages))
+	args = append(args,
+		"xclaim",
+		a.Stream,
+		a.Group, a.Consumer,
+		int64(a.MinIdle/time.Millisecond))
+	for _, id := range a.Messages {
+		args = append(args, id)
+	}
+	return args
+}
+
+func (c *cmdable) XTrim(key string, maxLen int64) *IntCmd {
+	cmd := NewIntCmd("xtrim", key, "maxlen", maxLen)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) XTrimApprox(key string, maxLen int64) *IntCmd {
+	cmd := NewIntCmd("xtrim", key, "maxlen", "~", maxLen)
+	c.process(cmd)
+	return cmd
+}
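A hedged end-to-end sketch of the consumer-group commands above, reusing the earlier `client` (group and consumer names are illustrative): create the group, read as a named consumer, then acknowledge what was processed:

    _ = client.XGroupCreate("events", "workers", "0").Err() // errors if the group already exists
    streams, err := client.XReadGroup(&redis.XReadGroupArgs{
        Group:    "workers",
        Consumer: "worker-1",
        Streams:  []string{"events", ">"}, // ">" = only never-delivered entries
        Count:    10,
        Block:    -1, // don't block
    }).Result()
    if err == nil {
        for _, msg := range streams[0].Messages {
            _ = client.XAck("events", "workers", msg.ID).Err()
        }
    }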
+
+//------------------------------------------------------------------------------
+
+// Z represents a sorted set member.
+type Z struct {
+	Score  float64
+	Member interface{}
+}
+
+// ZStore is used as an arg to ZInterStore and ZUnionStore.
+type ZStore struct {
+	Weights []float64
+	// Can be SUM, MIN or MAX.
+	Aggregate string
+}
+
+func (c *cmdable) zAdd(a []interface{}, n int, members ...Z) *IntCmd {
+	for i, m := range members {
+		a[n+2*i] = m.Score
+		a[n+2*i+1] = m.Member
+	}
+	cmd := NewIntCmd(a...)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `ZADD key score member [score member ...]` command.
+func (c *cmdable) ZAdd(key string, members ...Z) *IntCmd {
+	const n = 2
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1] = "zadd", key
+	return c.zAdd(a, n, members...)
+}
+
+// Redis `ZADD key NX score member [score member ...]` command.
+func (c *cmdable) ZAddNX(key string, members ...Z) *IntCmd {
+	const n = 3
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1], a[2] = "zadd", key, "nx"
+	return c.zAdd(a, n, members...)
+}
+
+// Redis `ZADD key XX score member [score member ...]` command.
+func (c *cmdable) ZAddXX(key string, members ...Z) *IntCmd {
+	const n = 3
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1], a[2] = "zadd", key, "xx"
+	return c.zAdd(a, n, members...)
+}
+
+// Redis `ZADD key CH score member [score member ...]` command.
+func (c *cmdable) ZAddCh(key string, members ...Z) *IntCmd {
+	const n = 3
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1], a[2] = "zadd", key, "ch"
+	return c.zAdd(a, n, members...)
+}
+
+// Redis `ZADD key NX CH score member [score member ...]` command.
+func (c *cmdable) ZAddNXCh(key string, members ...Z) *IntCmd {
+	const n = 4
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1], a[2], a[3] = "zadd", key, "nx", "ch"
+	return c.zAdd(a, n, members...)
+}
+
+// Redis `ZADD key XX CH score member [score member ...]` command.
+func (c *cmdable) ZAddXXCh(key string, members ...Z) *IntCmd {
+	const n = 4
+	a := make([]interface{}, n+2*len(members))
+	a[0], a[1], a[2], a[3] = "zadd", key, "xx", "ch"
+	return c.zAdd(a, n, members...)
+}
+
+func (c *cmdable) zIncr(a []interface{}, n int, members ...Z) *FloatCmd {
+	for i, m := range members {
+		a[n+2*i] = m.Score
+		a[n+2*i+1] = m.Member
+	}
+	cmd := NewFloatCmd(a...)
+	c.process(cmd)
+	return cmd
+}
+
+// Redis `ZADD key INCR score member` command.
+func (c *cmdable) ZIncr(key string, member Z) *FloatCmd {
+	const n = 3
+	a := make([]interface{}, n+2)
+	a[0], a[1], a[2] = "zadd", key, "incr"
+	return c.zIncr(a, n, member)
+}
+
+// Redis `ZADD key NX INCR score member` command.
+func (c *cmdable) ZIncrNX(key string, member Z) *FloatCmd {
+	const n = 4
+	a := make([]interface{}, n+2)
+	a[0], a[1], a[2], a[3] = "zadd", key, "incr", "nx"
+	return c.zIncr(a, n, member)
+}
+
+// Redis `ZADD key XX INCR score member` command.
+func (c *cmdable) ZIncrXX(key string, member Z) *FloatCmd {
+	const n = 4
+	a := make([]interface{}, n+2)
+	a[0], a[1], a[2], a[3] = "zadd", key, "incr", "xx"
+	return c.zIncr(a, n, member)
+}
+
+func (c *cmdable) ZCard(key string) *IntCmd {
+	cmd := NewIntCmd("zcard", key)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZCount(key, min, max string) *IntCmd {
+	cmd := NewIntCmd("zcount", key, min, max)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZLexCount(key, min, max string) *IntCmd {
+	cmd := NewIntCmd("zlexcount", key, min, max)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZIncrBy(key string, increment float64, member string) *FloatCmd {
+	cmd := NewFloatCmd("zincrby", key, increment, member)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZInterStore(destination string, store ZStore, keys ...string) *IntCmd {
+	args := make([]interface{}, 3+len(keys))
+	args[0] = "zinterstore"
+	args[1] = destination
+	args[2] = len(keys)
+	for i, key := range keys {
+		args[3+i] = key
+	}
+	if len(store.Weights) > 0 {
+		args = append(args, "weights")
+		for _, weight := range store.Weights {
+			args = append(args, weight)
+		}
+	}
+	if store.Aggregate != "" {
+		args = append(args, "aggregate", store.Aggregate)
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZPopMax(key string, count ...int64) *ZSliceCmd {
+	args := []interface{}{
+		"zpopmax",
+		key,
+	}
+
+	switch len(count) {
+	case 0:
+		break
+	case 1:
+		args = append(args, count[0])
+	default:
+		panic("too many arguments")
+	}
+
+	cmd := NewZSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZPopMin(key string, count ...int64) *ZSliceCmd {
+	args := []interface{}{
+		"zpopmin",
+		key,
+	}
+
+	switch len(count) {
+	case 0:
+		break
+	case 1:
+		args = append(args, count[0])
+	default:
+		panic("too many arguments")
+	}
+
+	cmd := NewZSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
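All of the ZADD variants above funnel into zAdd, which lays the Score/Member pairs into a pre-sized argument slice after the mode flags. Typical use, again with the earlier `client` (key and members illustrative):

    added, _ := client.ZAdd("leaderboard",
        redis.Z{Score: 42, Member: "alice"},
        redis.Z{Score: 17, Member: "bob"},
    ).Result()
    fmt.Println(added) // number of newly added members (2 on the first run)
    top, _ := client.ZPopMax("leaderboard").Result() // pops the highest-scored member
    fmt.Println(top)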
+
+func (c *cmdable) zRange(key string, start, stop int64, withScores bool) *StringSliceCmd {
+	args := []interface{}{
+		"zrange",
+		key,
+		start,
+		stop,
+	}
+	if withScores {
+		args = append(args, "withscores")
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRange(key string, start, stop int64) *StringSliceCmd {
+	return c.zRange(key, start, stop, false)
+}
+
+func (c *cmdable) ZRangeWithScores(key string, start, stop int64) *ZSliceCmd {
+	cmd := NewZSliceCmd("zrange", key, start, stop, "withscores")
+	c.process(cmd)
+	return cmd
+}
+
+type ZRangeBy struct {
+	Min, Max      string
+	Offset, Count int64
+}
+
+func (c *cmdable) zRangeBy(zcmd, key string, opt ZRangeBy, withScores bool) *StringSliceCmd {
+	args := []interface{}{zcmd, key, opt.Min, opt.Max}
+	if withScores {
+		args = append(args, "withscores")
+	}
+	if opt.Offset != 0 || opt.Count != 0 {
+		args = append(
+			args,
+			"limit",
+			opt.Offset,
+			opt.Count,
+		)
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRangeByScore(key string, opt ZRangeBy) *StringSliceCmd {
+	return c.zRangeBy("zrangebyscore", key, opt, false)
+}
+
+func (c *cmdable) ZRangeByLex(key string, opt ZRangeBy) *StringSliceCmd {
+	return c.zRangeBy("zrangebylex", key, opt, false)
+}
+
+func (c *cmdable) ZRangeByScoreWithScores(key string, opt ZRangeBy) *ZSliceCmd {
+	args := []interface{}{"zrangebyscore", key, opt.Min, opt.Max, "withscores"}
+	if opt.Offset != 0 || opt.Count != 0 {
+		args = append(
+			args,
+			"limit",
+			opt.Offset,
+			opt.Count,
+		)
+	}
+	cmd := NewZSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRank(key, member string) *IntCmd {
+	cmd := NewIntCmd("zrank", key, member)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRem(key string, members ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(members))
+	args[0] = "zrem"
+	args[1] = key
+	args = appendArgs(args, members)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRemRangeByRank(key string, start, stop int64) *IntCmd {
+	cmd := NewIntCmd(
+		"zremrangebyrank",
+		key,
+		start,
+		stop,
+	)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRemRangeByScore(key, min, max string) *IntCmd {
+	cmd := NewIntCmd("zremrangebyscore", key, min, max)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRemRangeByLex(key, min, max string) *IntCmd {
+	cmd := NewIntCmd("zremrangebylex", key, min, max)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRevRange(key string, start, stop int64) *StringSliceCmd {
+	cmd := NewStringSliceCmd("zrevrange", key, start, stop)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRevRangeWithScores(key string, start, stop int64) *ZSliceCmd {
+	cmd := NewZSliceCmd("zrevrange", key, start, stop, "withscores")
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) zRevRangeBy(zcmd, key string, opt ZRangeBy) *StringSliceCmd {
+	args := []interface{}{zcmd, key, opt.Max, opt.Min}
+	if opt.Offset != 0 || opt.Count != 0 {
+		args = append(
+			args,
+			"limit",
+			opt.Offset,
+			opt.Count,
+		)
+	}
+	cmd := NewStringSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZRevRangeByScore(key string, opt ZRangeBy) *StringSliceCmd {
+	return c.zRevRangeBy("zrevrangebyscore", key, opt)
+}
+
+func (c *cmdable) ZRevRangeByLex(key string, opt ZRangeBy) *StringSliceCmd {
+	return c.zRevRangeBy("zrevrangebylex", key, opt)
+}
+
+func (c *cmdable) ZRevRangeByScoreWithScores(key string, opt ZRangeBy) *ZSliceCmd {
+	args := []interface{}{"zrevrangebyscore", key, opt.Max, opt.Min, "withscores"}
+	if opt.Offset != 0 || opt.Count != 0 {
+		args = append(
+			args,
+			"limit",
+			opt.Offset,
+			opt.Count,
+		)
+	}
+	cmd := NewZSliceCmd(args...)
+	c.process(cmd)
+	return cmd
+}
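ZRangeBy takes Min and Max as strings because Redis score ranges admit exclusive bounds and infinities ("(1.5", "-inf", "+inf"), which plain float64 fields could not express. For example, with the `leaderboard` key from the earlier sketch:

    members, _ := client.ZRangeByScore("leaderboard", redis.ZRangeBy{
        Min: "(10",  // exclusive lower bound
        Max: "+inf", // no upper bound
    }).Result()
    fmt.Println(members)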
+
+func (c *cmdable) ZRevRank(key, member string) *IntCmd {
+	cmd := NewIntCmd("zrevrank", key, member)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZScore(key, member string) *FloatCmd {
+	cmd := NewFloatCmd("zscore", key, member)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ZUnionStore(dest string, store ZStore, keys ...string) *IntCmd {
+	args := make([]interface{}, 3+len(keys))
+	args[0] = "zunionstore"
+	args[1] = dest
+	args[2] = len(keys)
+	for i, key := range keys {
+		args[3+i] = key
+	}
+	if len(store.Weights) > 0 {
+		args = append(args, "weights")
+		for _, weight := range store.Weights {
+			args = append(args, weight)
+		}
+	}
+	if store.Aggregate != "" {
+		args = append(args, "aggregate", store.Aggregate)
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) PFAdd(key string, els ...interface{}) *IntCmd {
+	args := make([]interface{}, 2, 2+len(els))
+	args[0] = "pfadd"
+	args[1] = key
+	args = appendArgs(args, els)
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) PFCount(keys ...string) *IntCmd {
+	args := make([]interface{}, 1+len(keys))
+	args[0] = "pfcount"
+	for i, key := range keys {
+		args[1+i] = key
+	}
+	cmd := NewIntCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) PFMerge(dest string, keys ...string) *StatusCmd {
+	args := make([]interface{}, 2+len(keys))
+	args[0] = "pfmerge"
+	args[1] = dest
+	for i, key := range keys {
+		args[2+i] = key
+	}
+	cmd := NewStatusCmd(args...)
+	c.process(cmd)
+	return cmd
+}
+
+//------------------------------------------------------------------------------
+
+func (c *cmdable) BgRewriteAOF() *StatusCmd {
+	cmd := NewStatusCmd("bgrewriteaof")
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) BgSave() *StatusCmd {
+	cmd := NewStatusCmd("bgsave")
+	c.process(cmd)
+	return cmd
+}
+
+func (c *cmdable) ClientKill(ipPort string) *StatusCmd {
+	cmd := NewStatusCmd("client", "kill", ipPort)
+	c.process(cmd)
+	return cmd
+}
+
+// ClientKillByFilter is the new style syntax, while ClientKill is the old one.
+// CLIENT KILL