diff --git a/.gosimpleignore b/.gosimpleignore new file mode 100644 index 00000000..84df9b95 --- /dev/null +++ b/.gosimpleignore @@ -0,0 +1,4 @@ +github.com/astaxie/beego/*/*:S1012 +github.com/astaxie/beego/*:S1012 +github.com/astaxie/beego/*/*:S1007 +github.com/astaxie/beego/*:S1007 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index df3e923f..aa88a44d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,9 @@ language: go go: - - 1.6 - - 1.5.3 - - 1.4.3 + - 1.6.4 + - 1.7.5 + - 1.8.1 services: - redis-server - mysql @@ -33,6 +33,12 @@ install: - go get github.com/ssdb/gossdb/ssdb - go get github.com/cloudflare/golz4 - go get github.com/gogo/protobuf/proto + - go get github.com/Knetic/govaluate + - go get github.com/hsluoyz/casbin + - go get -u honnef.co/go/tools/cmd/gosimple + - go get -u github.com/mdempsky/unconvert + - go get -u github.com/gordonklaus/ineffassign + - go get -u github.com/golang/lint/golint before_script: - psql --version - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" @@ -47,5 +53,10 @@ after_script: - rm -rf ./res/var/* script: - go test -v ./... + - gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/) + - unconvert $(go list ./... | grep -v /vendor/) + - ineffassign . + - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s + - golint ./... addons: postgresql: "9.4" diff --git a/README.md b/README.md index d3c92d84..c08927fb 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,17 @@ -## Beego - -[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) -[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) -[![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) +# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) beego is used for rapid development of RESTful APIs, web apps and backend services in Go. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. -More info [beego.me](http://beego.me) +###### More info at [beego.me](http://beego.me). -##Quick Start -######Download and install +## Quick Start + +#### Download and install go get github.com/astaxie/beego -######Create file `hello.go` +#### Create file `hello.go` ```go package main @@ -24,15 +21,16 @@ func main(){ beego.Run() } ``` -######Build and run -```bash +#### Build and run + go build hello.go ./hello -``` -######Congratulations! -You just built your first beego app. -Open your browser and visit `http://localhost:8080`. -Please see [Documentation](http://beego.me/docs) for more. + +#### Go to [http://localhost:8080](http://localhost:8080) + +Congratulations! You've just built your first **beego** app. + +###### Please see [Documentation](http://beego.me/docs) for more. ## Features @@ -56,7 +54,7 @@ Please see [Documentation](http://beego.me/docs) for more. * [http://beego.me/community](http://beego.me/community) * Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232) -## LICENSE +## License beego source code is licensed under the Apache Licence, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0.html). diff --git a/admin.go b/admin.go index 71dafd0f..875cd0e8 100644 --- a/admin.go +++ b/admin.go @@ -140,8 +140,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) { resultList := new([][]string) for _, f := range bf { var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), + f.pattern, + utils.GetFuncName(f.filterFunc), } *resultList = append(*resultList, result) } @@ -213,12 +213,12 @@ func printTree(resultList *[][]string, t *Tree) { printTree(resultList, t.wildcard) } for _, l := range t.leaves { - if v, ok := l.runObject.(*controllerInfo); ok { + if v, ok := l.runObject.(*ControllerInfo); ok { if v.routerType == routerTypeBeego { var result = []string{ v.pattern, fmt.Sprintf("%s", v.methods), - fmt.Sprintf("%s", v.controllerType), + v.controllerType.String(), } *resultList = append(*resultList, result) } else if v.routerType == routerTypeRESTFul { @@ -281,8 +281,8 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { // it's in "/healthcheck" pattern in admin module. func healthcheck(rw http.ResponseWriter, req *http.Request) { var ( + result []string data = make(map[interface{}]interface{}) - result = []string{} resultList = new([][]string) content = map[string]interface{}{ "Fields": []string{"Name", "Message", "Status"}, @@ -292,21 +292,20 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) { for name, h := range toolbox.AdminCheckList { if err := h.Check(); err != nil { result = []string{ - fmt.Sprintf("error"), - fmt.Sprintf("%s", name), - fmt.Sprintf("%s", err.Error()), + "error", + name, + err.Error(), } - } else { result = []string{ - fmt.Sprintf("success"), - fmt.Sprintf("%s", name), - fmt.Sprintf("OK"), + "success", + name, + "OK", } - } *resultList = append(*resultList, result) } + content["Data"] = resultList data["Content"] = content data["Title"] = "Health Check" @@ -335,7 +334,6 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { // List Tasks content := make(map[string]interface{}) resultList := new([][]string) - var result = []string{} var fields = []string{ "Task Name", "Task Spec", @@ -344,10 +342,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { "", } for tname, tk := range toolbox.AdminTaskList { - result = []string{ + result := []string{ tname, - fmt.Sprintf("%s", tk.GetSpec()), - fmt.Sprintf("%s", tk.GetStatus()), + tk.GetSpec(), + tk.GetStatus(), tk.GetPrev().String(), } *resultList = append(*resultList, result) diff --git a/app.go b/app.go index 32776298..25ea2a04 100644 --- a/app.go +++ b/app.go @@ -348,9 +348,9 @@ func Any(rootpath string, f FilterFunc) *App { // Handler used to register a Handler router // usage: -// beego.Handler("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) func Handler(rootpath string, h http.Handler, options ...interface{}) *App { BeeApp.Handlers.Handler(rootpath, h, options...) return BeeApp diff --git a/beego.go b/beego.go index c06b499c..22079a20 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.8.0" + VERSION = "1.8.3" // DEV is for develop DEV = "dev" diff --git a/cache/conv.go b/cache/conv.go index dbdff1c7..87800586 100644 --- a/cache/conv.go +++ b/cache/conv.go @@ -28,7 +28,7 @@ func GetString(v interface{}) string { return string(result) default: if v != nil { - return fmt.Sprintf("%v", result) + return fmt.Sprint(result) } } return "" diff --git a/cache/conv_test.go b/cache/conv_test.go index cf792fa6..b90e224a 100644 --- a/cache/conv_test.go +++ b/cache/conv_test.go @@ -118,14 +118,14 @@ func TestGetFloat64(t *testing.T) { func TestGetBool(t *testing.T) { var t1 = true - if true != GetBool(t1) { + if !GetBool(t1) { t.Error("get bool from bool error") } var t2 = "true" - if true != GetBool(t2) { + if !GetBool(t2) { t.Error("get bool from string error") } - if false != GetBool(nil) { + if GetBool(nil) { t.Error("get bool from nil error") } } diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go index 972361f7..0624f5fa 100644 --- a/cache/memcache/memcache.go +++ b/cache/memcache/memcache.go @@ -146,10 +146,7 @@ func (rc *Cache) IsExist(key string) bool { } } _, err := rc.conn.Get(key) - if err != nil { - return false - } - return true + return !(err != nil) } // ClearAll clear all cached in memcache. diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 781e3836..3e71fb53 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -137,7 +137,7 @@ func (rc *Cache) IsExist(key string) bool { if err != nil { return false } - if v == false { + if !v { if _, err = rc.do("HDEL", rc.key, key); err != nil { return false } diff --git a/cache/ssdb/ssdb.go b/cache/ssdb/ssdb.go index bbc43606..fa2ce04b 100644 --- a/cache/ssdb/ssdb.go +++ b/cache/ssdb/ssdb.go @@ -53,7 +53,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { resSize := len(res) if err == nil { for i := 1; i < resSize; i += 2 { - values = append(values, string(res[i+1])) + values = append(values, res[i+1]) } return values } @@ -71,10 +71,7 @@ func (rc *Cache) DelMulti(keys []string) error { } } _, err := rc.conn.Do("multi_del", keys) - if err != nil { - return err - } - return nil + return err } // Put put value to memcache. only support string. @@ -113,10 +110,7 @@ func (rc *Cache) Delete(key string) error { } } _, err := rc.conn.Del(key) - if err != nil { - return err - } - return nil + return err } // Incr increase counter. @@ -175,7 +169,7 @@ func (rc *Cache) ClearAll() error { } keys := []string{} for i := 1; i < size; i += 2 { - keys = append(keys, string(resp[i])) + keys = append(keys, resp[i]) } _, e := rc.conn.Do("multi_del", keys) if e != nil { @@ -229,10 +223,7 @@ func (rc *Cache) connectInit() error { } var err error rc.conn, err = ssdb.Connect(host, port) - if err != nil { - return err - } - return nil + return err } func init() { diff --git a/config.go b/config.go index 3c202e53..e6e99570 100644 --- a/config.go +++ b/config.go @@ -345,7 +345,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) { case reflect.String: pf.SetString(ac.DefaultString(name, pf.String())) case reflect.Int, reflect.Int64: - pf.SetInt(int64(ac.DefaultInt64(name, pf.Int()))) + pf.SetInt(ac.DefaultInt64(name, pf.Int())) case reflect.Bool: pf.SetBool(ac.DefaultBool(name, pf.Bool())) case reflect.Struct: diff --git a/config/env/env.go b/config/env/env.go index a819e51a..34f094fe 100644 --- a/config/env/env.go +++ b/config/env/env.go @@ -12,6 +12,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +// Package env is used to parse environment. package env import ( diff --git a/config/ini.go b/config/ini.go index 27220f90..a681bc1b 100644 --- a/config/ini.go +++ b/config/ini.go @@ -21,6 +21,7 @@ import ( "io" "io/ioutil" "os" + "os/user" "path/filepath" "strconv" "strings" @@ -184,10 +185,17 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e // ParseData parse ini the data // When include other.conf,other.conf is either absolute directory -// or under beego in default temporary directory(/tmp/beego). +// or under beego in default temporary directory(/tmp/beego[-username]). func (ini *IniConfig) ParseData(data []byte) (Configer, error) { - dir := filepath.Join(os.TempDir(), "beego") - os.MkdirAll(dir, os.ModePerm) + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, err + } return ini.parseData(dir, data) } @@ -317,7 +325,10 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Get section or key comments. Fixed #1607 getCommentStr := func(section, key string) string { - comment, ok := "", false + var ( + comment string + ok bool + ) if len(key) == 0 { comment, ok = c.sectionComment[section] } else { @@ -397,11 +408,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { } } } - - if _, err = buf.WriteTo(f); err != nil { - return err - } - return nil + _, err = buf.WriteTo(f) + return err } // Set writes a new value for key. @@ -416,7 +424,7 @@ func (c *IniConfigContainer) Set(key, value string) error { var ( section, k string - sectionKey = strings.Split(key, "::") + sectionKey = strings.Split(strings.ToLower(key), "::") ) if len(sectionKey) >= 2 { diff --git a/config/ini_test.go b/config/ini_test.go index 83ff3668..ffcdb294 100644 --- a/config/ini_test.go +++ b/config/ini_test.go @@ -181,7 +181,7 @@ name=mysql cfgData := string(data) datas := strings.Split(saveResult, "\n") for _, line := range datas { - if strings.Contains(cfgData, line+"\n") == false { + if !strings.Contains(cfgData, line+"\n") { t.Fatalf("different after save ini config file. need contains %q", line) } } diff --git a/context/acceptencoder.go b/context/acceptencoder.go index 350b560d..b4e2492c 100644 --- a/context/acceptencoder.go +++ b/context/acceptencoder.go @@ -39,6 +39,7 @@ var ( getMethodOnly bool ) +// InitGzip init the gzipcompress func InitGzip(minLength, compressLevel int, methods []string) { if minLength >= 0 { gzipMinLength = minLength diff --git a/context/context.go b/context/context.go index 03286097..8b32062c 100644 --- a/context/context.go +++ b/context/context.go @@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool { return true } +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + //Response is a wrapper for the http.ResponseWriter //started set to true if response was written to then don't execute other handler type Response struct { diff --git a/context/input_test.go b/context/input_test.go index 9853e398..db812a0f 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -73,8 +73,8 @@ func TestBind(t *testing.T) { {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", []testItem{{"human", []Human{}, []Human{ - Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, - Human{ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, }}}}, { diff --git a/context/output.go b/context/output.go index 564ef96d..cf9e7a7e 100644 --- a/context/output.go +++ b/context/output.go @@ -168,6 +168,19 @@ func sanitizeValue(v string) string { return cookieValueSanitizer.Replace(v) } +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.WriteString(err.Error()) + }) +} + // JSON writes json to response body. // if coding is true, it converts utf-8 to \u0000 type. func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { @@ -330,9 +343,8 @@ func (output *BeegoOutput) IsServerError() bool { } func stringsToJSON(str string) string { - rs := []rune(str) var jsons bytes.Buffer - for _, r := range rs { + for _, r := range str { rint := int(r) if rint < 128 { jsons.WriteRune(r) diff --git a/context/param/conv.go b/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/context/param/methodparams.go b/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/context/param/options.go b/context/param/options.go new file mode 100644 index 00000000..58bdc3d0 --- /dev/null +++ b/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be ommited from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/context/param/parsers.go b/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/context/param/parsers_test.go b/context/param/parsers_test.go new file mode 100644 index 00000000..b946ba08 --- /dev/null +++ b/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Convertion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/context/renderer.go b/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/context/response.go b/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/controller.go b/controller.go index 488ffcda..510e16b8 100644 --- a/controller.go +++ b/controller.go @@ -28,6 +28,7 @@ import ( "strings" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/session" ) @@ -51,6 +52,7 @@ type ControllerComments struct { Router string AllowHTTPMethods []string Params []map[string]string + MethodParams []*param.MethodParam } // Controller defines some basic http request handler operations, such as @@ -223,7 +225,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { } buf.Reset() - ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath() ,c.Data) + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) } return buf.Bytes(), err } @@ -249,7 +251,7 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) { } } } - BuildTemplate(c.viewPath() , buildFiles...) + BuildTemplate(c.viewPath(), buildFiles...) } return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) } @@ -314,7 +316,7 @@ func (c *Controller) ServeJSON(encoding ...bool) { if BConfig.RunMode == PROD { hasIndent = false } - if len(encoding) > 0 && encoding[0] == true { + if len(encoding) > 0 && encoding[0] { hasEncoding = true } c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) diff --git a/controller_test.go b/controller_test.go index c2025860..1e53416d 100644 --- a/controller_test.go +++ b/controller_test.go @@ -172,10 +172,10 @@ func TestAdditionalViewPaths(t *testing.T) { t.Fatal("TestAdditionalViewPaths expected error") } }() - ctrl.RenderString(); + ctrl.RenderString() }() ctrl.TplName = "file2.tpl" ctrl.ViewPath = dir2 - ctrl.RenderString(); + ctrl.RenderString() } diff --git a/error.go b/error.go index ab626247..b913db39 100644 --- a/error.go +++ b/error.go @@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) { ) } +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) +} + // show 404 not found error. func notFound(rw http.ResponseWriter, r *http.Request) { responseError(rw, r, diff --git a/flash_test.go b/flash_test.go index 640d54de..d5e9608d 100644 --- a/flash_test.go +++ b/flash_test.go @@ -48,7 +48,7 @@ func TestFlashHeader(t *testing.T) { // match for the expected header res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") // validate the assertion - if res != true { + if !res { t.Errorf("TestFlashHeader() unable to validate flash message") } } diff --git a/grace/listener.go b/grace/listener.go index 5439d0b2..823d3cce 100644 --- a/grace/listener.go +++ b/grace/listener.go @@ -21,7 +21,7 @@ func newGraceListener(l net.Listener, srv *Server) (el *graceListener) { server: srv, } go func() { - _ = <-el.stop + <-el.stop el.stopped = true el.stop <- el.Listener.Close() }() diff --git a/grace/server.go b/grace/server.go index cc985552..b8242335 100644 --- a/grace/server.go +++ b/grace/server.go @@ -196,7 +196,6 @@ func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { for _, f := range srv.SignalHooks[ppFlag][sig] { f() } - return } // shutdown closes the listener so that no new connections are accepted. it also @@ -292,7 +291,7 @@ func (srv *Server) fork() (err error) { // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { if ppFlag != PreSignal && ppFlag != PostSignal { - err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal.") + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") return } for _, s := range hookableSignals { @@ -301,6 +300,6 @@ func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err return } } - err = fmt.Errorf("Signal '%v' is not supported.", sig) + err = fmt.Errorf("Signal '%v' is not supported", sig) return } diff --git a/hooks.go b/hooks.go index 091ecbc7..c5ec8e2d 100644 --- a/hooks.go +++ b/hooks.go @@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error { "502": badGateway, "503": serviceUnavailable, "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, } for e, h := range m { if _, ok := ErrorMaps[e]; !ok { @@ -55,9 +57,9 @@ func registerSession() error { conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly conf.Domain = BConfig.WebConfig.Session.SessionDomain - conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader - conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader - conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { return err diff --git a/httplib/README.md b/httplib/README.md index 6a72cf7c..97df8e6b 100644 --- a/httplib/README.md +++ b/httplib/README.md @@ -32,7 +32,7 @@ The default timeout is `60` seconds, function prototype: SetTimeout(connectTimeout, readWriteTimeout time.Duration) -Exmaple: +Example: // GET httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) diff --git a/httplib/httplib.go b/httplib/httplib.go index 39480469..4fd572d6 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -335,7 +335,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) func (b *BeegoHTTPRequest) buildURL(paramBody string) { // build GET url with query string if b.req.Method == "GET" && len(paramBody) > 0 { - if strings.Index(b.url, "?") != -1 { + if strings.Contains(b.url, "?") { b.url += "&" + paramBody } else { b.url = b.url + "?" + paramBody @@ -344,7 +344,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { } // build POST/PUT/PATCH url and body - if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil { + if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { // with files if len(b.files) > 0 { pr, pw := io.Pipe() @@ -520,9 +520,9 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { return nil, err } b.body, err = ioutil.ReadAll(reader) - } else { - b.body, err = ioutil.ReadAll(resp.Body) + return b.body, err } + b.body, err = ioutil.ReadAll(resp.Body) return b.body, err } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 05815054..32d3e7f6 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -102,6 +102,14 @@ func TestSimpleDelete(t *testing.T) { t.Log(str) } +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + func TestWithCookie(t *testing.T) { v := "smallfish" str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() diff --git a/logs/alils/alils.go b/logs/alils/alils.go index 30a09243..867ff4cb 100644 --- a/logs/alils/alils.go +++ b/logs/alils/alils.go @@ -2,19 +2,23 @@ package alils import ( "encoding/json" - "github.com/astaxie/beego/logs" - "github.com/gogo/protobuf/proto" "strings" "sync" "time" + + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" ) const ( - CacheSize int = 64 + // CacheSize set the flush size + CacheSize int = 64 + // Delimiter define the topic delimiter Delimiter string = "##" ) -type AliLSConfig struct { +// Config is the Config for Ali Log +type Config struct { Project string `json:"project"` Endpoint string `json:"endpoint"` KeyID string `json:"key_id"` @@ -34,18 +38,17 @@ type aliLSWriter struct { withMap bool groupMap map[string]*LogGroup lock *sync.Mutex - AliLSConfig + Config } -// 创建提供Logger接口的日志服务 +// NewAliLS create a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace return alils } -// 读取配置 -// 初始化必要的数据结构 +// Init parse config and init struct func (c *aliLSWriter) Init(jsonConfig string) (err error) { json.Unmarshal([]byte(jsonConfig), c) @@ -54,28 +57,26 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { c.FlushWhen = CacheSize } - // 初始化Project prj := &LogProject{ Name: c.Project, Endpoint: c.Endpoint, - AccessKeyId: c.KeyID, + AccessKeyID: c.KeyID, AccessKeySecret: c.KeySecret, } - // 获取logstore c.store, err = prj.GetLogStore(c.LogStore) if err != nil { return err } - // 创建默认Log Group + // Create default Log Group c.group = append(c.group, &LogGroup{ Topic: proto.String(""), Source: proto.String(c.Source), Logs: make([]*Log, 0, c.FlushWhen), }) - // 创建其它Log Group + // Create other Log Group c.groupMap = make(map[string]*LogGroup) for _, topic := range c.Topics { @@ -113,7 +114,7 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error var lg *LogGroup if c.withMap { - // 解析出Topic,并匹配LogGroup + // Topic,LogGroup strs := strings.SplitN(msg, Delimiter, 2) if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") @@ -122,27 +123,24 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error lg = c.groupMap[topic] } - // 默认发到空Topic + // send to empty Topic if lg == nil { - topic = "" content = msg lg = c.group[0] } } else { - topic = "" content = msg lg = c.group[0] } - // 生成日志 - c1 := &Log_Content{ + c1 := &LogContent{ Key: proto.String("msg"), Value: proto.String(content), } l := &Log{ - Time: proto.Uint32(uint32(when.Unix())), // 填写日志时间 - Contents: []*Log_Content{ + Time: proto.Uint32(uint32(when.Unix())), + Contents: []*LogContent{ c1, }, } @@ -151,7 +149,6 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error lg.Logs = append(lg.Logs, l) c.lock.Unlock() - // 满足条件则Flush if len(lg.Logs) >= c.FlushWhen { c.flush(lg) } @@ -162,7 +159,7 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error // Flush implementing method. empty. func (c *aliLSWriter) Flush() { - // flush所有group + // flush all group for _, lg := range c.group { c.flush(lg) } @@ -176,9 +173,6 @@ func (c *aliLSWriter) flush(lg *LogGroup) { c.lock.Lock() defer c.lock.Unlock() - - // 把以上的LogGroup推送到SLS服务器, - // SLS服务器会根据该logstore的shard个数自动进行负载均衡。 err := c.store.PutLogs(lg) if err != nil { return diff --git a/logs/alils/log.pb.go b/logs/alils/log.pb.go index 42f7e892..601b0d78 100755 --- a/logs/alils/log.pb.go +++ b/logs/alils/log.pb.go @@ -1,30 +1,43 @@ package alils -import "github.com/gogo/protobuf/proto" -import "fmt" -import "math" +import ( + "fmt" + "io" + "math" -// discarding unused import gogoproto "." - -import github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" - -import "io" + "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" +) // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal var _ = fmt.Errorf var _ = math.Inf +var ( + // ErrInvalidLengthLog invalid proto + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + // ErrIntOverflowLog overflow + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) + +// Log define the proto Log type Log struct { - Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` - Contents []*Log_Content `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` - XXX_unrecognized []byte `json:"-"` + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXXUnrecognized []byte `json:"-"` } -func (m *Log) Reset() { *m = Log{} } -func (m *Log) String() string { return proto.CompactTextString(m) } -func (*Log) ProtoMessage() {} +// Reset the Log +func (m *Log) Reset() { *m = Log{} } +// String return the Compact Log +func (m *Log) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*Log) ProtoMessage() {} + +// GetTime return the Log's Time func (m *Log) GetTime() uint32 { if m != nil && m.Time != nil { return *m.Time @@ -32,49 +45,65 @@ func (m *Log) GetTime() uint32 { return 0 } -func (m *Log) GetContents() []*Log_Content { +// GetContents return the Log's Contents +func (m *Log) GetContents() []*LogContent { if m != nil { return m.Contents } return nil } -type Log_Content struct { - Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` - Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` - XXX_unrecognized []byte `json:"-"` +// LogContent define the Log content struct +type LogContent struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXXUnrecognized []byte `json:"-"` } -func (m *Log_Content) Reset() { *m = Log_Content{} } -func (m *Log_Content) String() string { return proto.CompactTextString(m) } -func (*Log_Content) ProtoMessage() {} +// Reset LogContent +func (m *LogContent) Reset() { *m = LogContent{} } -func (m *Log_Content) GetKey() string { +// String return the compact text +func (m *LogContent) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogContent) ProtoMessage() {} + +// GetKey return the Key +func (m *LogContent) GetKey() string { if m != nil && m.Key != nil { return *m.Key } return "" } -func (m *Log_Content) GetValue() string { +// GetValue return the Value +func (m *LogContent) GetValue() string { if m != nil && m.Value != nil { return *m.Value } return "" } +// LogGroup define the logs struct type LogGroup struct { - Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` - Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` - Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` - Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` - XXX_unrecognized []byte `json:"-"` + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXXUnrecognized []byte `json:"-"` } -func (m *LogGroup) Reset() { *m = LogGroup{} } -func (m *LogGroup) String() string { return proto.CompactTextString(m) } -func (*LogGroup) ProtoMessage() {} +// Reset LogGroup +func (m *LogGroup) Reset() { *m = LogGroup{} } +// String return the compact text +func (m *LogGroup) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroup) ProtoMessage() {} + +// GetLogs return the loggroup logs func (m *LogGroup) GetLogs() []*Log { if m != nil { return m.Logs @@ -82,6 +111,7 @@ func (m *LogGroup) GetLogs() []*Log { return nil } +// GetReserved return Reserved func (m *LogGroup) GetReserved() string { if m != nil && m.Reserved != nil { return *m.Reserved @@ -89,6 +119,7 @@ func (m *LogGroup) GetReserved() string { return "" } +// GetTopic return Topic func (m *LogGroup) GetTopic() string { if m != nil && m.Topic != nil { return *m.Topic @@ -96,6 +127,7 @@ func (m *LogGroup) GetTopic() string { return "" } +// GetSource return Source func (m *LogGroup) GetSource() string { if m != nil && m.Source != nil { return *m.Source @@ -103,15 +135,22 @@ func (m *LogGroup) GetSource() string { return "" } +// LogGroupList define the LogGroups type LogGroupList struct { - LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` - XXX_unrecognized []byte `json:"-"` + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXXUnrecognized []byte `json:"-"` } -func (m *LogGroupList) Reset() { *m = LogGroupList{} } -func (m *LogGroupList) String() string { return proto.CompactTextString(m) } -func (*LogGroupList) ProtoMessage() {} +// Reset LogGroupList +func (m *LogGroupList) Reset() { *m = LogGroupList{} } +// String return compact text +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroupList) ProtoMessage() {} + +// GetLogGroups return the LogGroups func (m *LogGroupList) GetLogGroups() []*LogGroup { if m != nil { return m.LogGroups @@ -119,6 +158,7 @@ func (m *LogGroupList) GetLogGroups() []*LogGroup { return nil } +// Marshal the logs to byte slice func (m *Log) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) @@ -129,6 +169,7 @@ func (m *Log) Marshal() (data []byte, err error) { return data[:n], nil } +// MarshalTo data func (m *Log) MarshalTo(data []byte) (int, error) { var i int _ = i @@ -136,11 +177,10 @@ func (m *Log) MarshalTo(data []byte) (int, error) { _ = l if m.Time == nil { return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") - } else { - data[i] = 0x8 - i++ - i = encodeVarintLog(data, i, uint64(*m.Time)) } + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) if len(m.Contents) > 0 { for _, msg := range m.Contents { data[i] = 0x12 @@ -153,13 +193,14 @@ func (m *Log) MarshalTo(data []byte) (int, error) { i += n } } - if m.XXX_unrecognized != nil { - i += copy(data[i:], m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) } return i, nil } -func (m *Log_Content) Marshal() (data []byte, err error) { +// Marshal LogContent +func (m *LogContent) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) n, err := m.MarshalTo(data) @@ -169,33 +210,34 @@ func (m *Log_Content) Marshal() (data []byte, err error) { return data[:n], nil } -func (m *Log_Content) MarshalTo(data []byte) (int, error) { +// MarshalTo logcontent to data +func (m *LogContent) MarshalTo(data []byte) (int, error) { var i int _ = i var l int _ = l if m.Key == nil { return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") - } else { - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Key))) - i += copy(data[i:], *m.Key) } + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + if m.Value == nil { return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") - } else { - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Value))) - i += copy(data[i:], *m.Value) } - if m.XXX_unrecognized != nil { - i += copy(data[i:], m.XXX_unrecognized) + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) } return i, nil } +// Marshal LogGroup func (m *LogGroup) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) @@ -206,6 +248,7 @@ func (m *LogGroup) Marshal() (data []byte, err error) { return data[:n], nil } +// MarshalTo LogGroup to data func (m *LogGroup) MarshalTo(data []byte) (int, error) { var i int _ = i @@ -241,12 +284,13 @@ func (m *LogGroup) MarshalTo(data []byte) (int, error) { i = encodeVarintLog(data, i, uint64(len(*m.Source))) i += copy(data[i:], *m.Source) } - if m.XXX_unrecognized != nil { - i += copy(data[i:], m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) } return i, nil } +// Marshal LogGroupList func (m *LogGroupList) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) @@ -257,6 +301,7 @@ func (m *LogGroupList) Marshal() (data []byte, err error) { return data[:n], nil } +// MarshalTo LogGroupList to data func (m *LogGroupList) MarshalTo(data []byte) (int, error) { var i int _ = i @@ -274,8 +319,8 @@ func (m *LogGroupList) MarshalTo(data []byte) (int, error) { i += n } } - if m.XXX_unrecognized != nil { - i += copy(data[i:], m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) } return i, nil } @@ -307,6 +352,8 @@ func encodeVarintLog(data []byte, offset int, v uint64) int { data[offset] = uint8(v) return offset + 1 } + +// Size return the log's size func (m *Log) Size() (n int) { var l int _ = l @@ -319,13 +366,14 @@ func (m *Log) Size() (n int) { n += 1 + l + sovLog(uint64(l)) } } - if m.XXX_unrecognized != nil { - n += len(m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) } return n } -func (m *Log_Content) Size() (n int) { +// Size return LogContent size based on Key and Value +func (m *LogContent) Size() (n int) { var l int _ = l if m.Key != nil { @@ -336,12 +384,13 @@ func (m *Log_Content) Size() (n int) { l = len(*m.Value) n += 1 + l + sovLog(uint64(l)) } - if m.XXX_unrecognized != nil { - n += len(m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) } return n } +// Size return LogGroup size based on Logs func (m *LogGroup) Size() (n int) { var l int _ = l @@ -363,12 +412,13 @@ func (m *LogGroup) Size() (n int) { l = len(*m.Source) n += 1 + l + sovLog(uint64(l)) } - if m.XXX_unrecognized != nil { - n += len(m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) } return n } +// Size return LogGroupList size func (m *LogGroupList) Size() (n int) { var l int _ = l @@ -378,8 +428,8 @@ func (m *LogGroupList) Size() (n int) { n += 1 + l + sovLog(uint64(l)) } } - if m.XXX_unrecognized != nil { - n += len(m.XXX_unrecognized) + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) } return n } @@ -395,8 +445,10 @@ func sovLog(x uint64) (n int) { return n } func sozLog(x uint64) (n int) { - return sovLog(uint64((x << 1) ^ uint64((int64(x) >> 63)))) + return sovLog((x << 1) ^ (x >> 63)) } + +// Unmarshal data to log func (m *Log) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) @@ -474,7 +526,7 @@ func (m *Log) Unmarshal(data []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Contents = append(m.Contents, &Log_Content{}) + m.Contents = append(m.Contents, &LogContent{}) if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { return err } @@ -491,7 +543,7 @@ func (m *Log) Unmarshal(data []byte) error { if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } - m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) iNdEx += skippy } } @@ -504,7 +556,9 @@ func (m *Log) Unmarshal(data []byte) error { } return nil } -func (m *Log_Content) Unmarshal(data []byte) error { + +// Unmarshal data to LogContent +func (m *LogContent) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) iNdEx := 0 @@ -608,7 +662,7 @@ func (m *Log_Content) Unmarshal(data []byte) error { if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } - m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) iNdEx += skippy } } @@ -624,6 +678,8 @@ func (m *Log_Content) Unmarshal(data []byte) error { } return nil } + +// Unmarshal data to LogGroup func (m *LogGroup) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 @@ -786,7 +842,7 @@ func (m *LogGroup) Unmarshal(data []byte) error { if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } - m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) iNdEx += skippy } } @@ -796,6 +852,8 @@ func (m *LogGroup) Unmarshal(data []byte) error { } return nil } + +// Unmarshal data to LogGroupList func (m *LogGroupList) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 @@ -868,7 +926,7 @@ func (m *LogGroupList) Unmarshal(data []byte) error { if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } - m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) iNdEx += skippy } } @@ -878,6 +936,7 @@ func (m *LogGroupList) Unmarshal(data []byte) error { } return nil } + func skipLog(data []byte) (n int, err error) { l := len(data) iNdEx := 0 @@ -940,7 +999,7 @@ func skipLog(data []byte) (n int, err error) { case 3: for { var innerWire uint64 - var start int = iNdEx + var start = iNdEx for shift := uint(0); ; shift += 7 { if shift >= 64 { return 0, ErrIntOverflowLog @@ -977,8 +1036,3 @@ func skipLog(data []byte) (n int, err error) { } panic("unreachable") } - -var ( - ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") - ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") -) diff --git a/logs/alils/log_config.go b/logs/alils/log_config.go index 41fa0959..e8564efb 100755 --- a/logs/alils/log_config.go +++ b/logs/alils/log_config.go @@ -1,5 +1,6 @@ package alils +// InputDetail define log detail type InputDetail struct { LogType string `json:"logType"` LogPath string `json:"logPath"` @@ -14,11 +15,13 @@ type InputDetail struct { TopicFormat string `json:"topicFormat"` } +// OutputDetail define the output detail type OutputDetail struct { Endpoint string `json:"endpoint"` LogStoreName string `json:"logstoreName"` } +// LogConfig define Log Config type LogConfig struct { Name string `json:"configName"` InputType string `json:"inputType"` diff --git a/logs/alils/log_project.go b/logs/alils/log_project.go index 63ab07f8..59db8cbf 100755 --- a/logs/alils/log_project.go +++ b/logs/alils/log_project.go @@ -1,5 +1,5 @@ /* -Package sls implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). +Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). For more description about SLS, please read this article: http://gitlab.alibaba-inc.com/sls/doc. @@ -20,19 +20,20 @@ type errorMessage struct { Message string `json:"errorMessage"` } +// LogProject Define the Ali Project detail type LogProject struct { Name string // Project name Endpoint string // IP or hostname of SLS endpoint - AccessKeyId string + AccessKeyID string AccessKeySecret string } // NewLogProject creates a new SLS project. -func NewLogProject(name, endpoint, accessKeyId, accessKeySecret string) (p *LogProject, err error) { +func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { p = &LogProject{ Name: name, Endpoint: endpoint, - AccessKeyId: accessKeyId, + AccessKeyID: AccessKeyID, AccessKeySecret: accessKeySecret, } return p, nil diff --git a/logs/alils/log_store.go b/logs/alils/log_store.go index 009e39c4..fa502736 100755 --- a/logs/alils/log_store.go +++ b/logs/alils/log_store.go @@ -12,6 +12,7 @@ import ( "github.com/gogo/protobuf/proto" ) +// LogStore Store the logs type LogStore struct { Name string `json:"logstoreName"` TTL int @@ -23,6 +24,7 @@ type LogStore struct { project *LogProject } +// Shard define the Log Shard type Shard struct { ShardID int `json:"shardID"` } @@ -116,16 +118,16 @@ func (s *LogStore) PutLogs(lg *LogGroup) (err error) { return } -// GetCursor gets log cursor of one shard specified by shardId. +// GetCursor gets log cursor of one shard specified by shardID. // The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". // For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore -func (s *LogStore) GetCursor(shardId int, from string) (cursor string, err error) { +func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { h := map[string]string{ "x-sls-bodyrawsize": "0", } uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", - s.Name, shardId, from) + s.Name, shardID, from) r, err := request(s.project, "GET", uri, h, nil) if err != nil { @@ -163,10 +165,10 @@ func (s *LogStore) GetCursor(shardId int, from string) (cursor string, err error return } -// GetLogsBytes gets logs binary data from shard specified by shardId according cursor. +// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. // The logGroupMaxCount is the max number of logGroup could be returned. // The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogsBytes(shardId int, cursor string, +func (s *LogStore) GetLogsBytes(shardID int, cursor string, logGroupMaxCount int) (out []byte, nextCursor string, err error) { h := map[string]string{ @@ -176,7 +178,7 @@ func (s *LogStore) GetLogsBytes(shardId int, cursor string, } uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", - s.Name, shardId, cursor, logGroupMaxCount) + s.Name, shardID, cursor, logGroupMaxCount) r, err := request(s.project, "GET", uri, h, nil) if err != nil { @@ -249,13 +251,13 @@ func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { return } -// GetLogs gets logs from shard specified by shardId according cursor. +// GetLogs gets logs from shard specified by shardID according cursor. // The logGroupMaxCount is the max number of logGroup could be returned. // The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogs(shardId int, cursor string, +func (s *LogStore) GetLogs(shardID int, cursor string, logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { - out, nextCursor, err := s.GetLogsBytes(shardId, cursor, logGroupMaxCount) + out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) if err != nil { return } diff --git a/logs/alils/machine_group.go b/logs/alils/machine_group.go index 7a0aace1..b6c69a14 100755 --- a/logs/alils/machine_group.go +++ b/logs/alils/machine_group.go @@ -8,18 +8,20 @@ import ( "net/http/httputil" ) -type MachinGroupAttribute struct { +// MachineGroupAttribute define the Attribute +type MachineGroupAttribute struct { ExternalName string `json:"externalName"` TopicName string `json:"groupTopic"` } +// MachineGroup define the machine Group type MachineGroup struct { Name string `json:"groupName"` Type string `json:"groupType"` - MachineIdType string `json:"machineIdentifyType"` - MachineIdList []string `json:"machineList"` + MachineIDType string `json:"machineIdentifyType"` + MachineIDList []string `json:"machineList"` - Attribute MachinGroupAttribute `json:"groupAttribute"` + Attribute MachineGroupAttribute `json:"groupAttribute"` CreateTime uint32 LastModifyTime uint32 @@ -27,12 +29,14 @@ type MachineGroup struct { project *LogProject } +// Machine define the Machine type Machine struct { IP string - UniqueId string `json:"machine-uniqueid"` - UserdefinedId string `json:"userdefined-id"` + UniqueID string `json:"machine-uniqueid"` + UserdefinedID string `json:"userdefined-id"` } +// MachineList define the Machine List type MachineList struct { Total int Machines []*Machine diff --git a/logs/alils/request.go b/logs/alils/request.go index 20df45b4..50d9c43c 100755 --- a/logs/alils/request.go +++ b/logs/alils/request.go @@ -33,12 +33,12 @@ func request(project *LogProject, method, uri string, headers map[string]string, } // Calc Authorization - // Authorization = "SLS :" + // Authorization = "SLS :" digest, err := signature(project, method, uri, headers) if err != nil { return } - auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyId, digest) + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) headers["Authorization"] = auth // Initialize http request diff --git a/logs/alils/signature.go b/logs/alils/signature.go index e0e4b3f7..2d611307 100755 --- a/logs/alils/signature.go +++ b/logs/alils/signature.go @@ -76,7 +76,7 @@ func signature(project *LogProject, method, uri string, var keys sort.StringSlice vals := u.Query() - for k, _ := range vals { + for k := range vals { keys = append(keys, k) } @@ -109,4 +109,3 @@ func signature(project *LogProject, method, uri string, digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) return } - diff --git a/logs/color_windows.go b/logs/color_windows.go index deee4c87..4e28f188 100644 --- a/logs/color_windows.go +++ b/logs/color_windows.go @@ -361,7 +361,7 @@ func isParameterChar(b byte) bool { } func (cw *ansiColorWriter) Write(p []byte) (int, error) { - r, nw, first, last := 0, 0, 0, 0 + var r, nw, first, last int if cw.mode != DiscardNonColorEscSeq { cw.state = outsideCsiCode cw.resetBuffer() diff --git a/logs/console.go b/logs/console.go index e6bf6c29..e75f2a1b 100644 --- a/logs/console.go +++ b/logs/console.go @@ -41,7 +41,7 @@ var colors = []brush{ newBrush("1;33"), // Warning yellow newBrush("1;32"), // Notice green newBrush("1;34"), // Informational blue - newBrush("1;34"), // Debug blue + newBrush("1;44"), // Debug Background blue } // consoleWriter implements LoggerInterface and writes messages to terminal. diff --git a/logs/file.go b/logs/file.go index bd3c22a9..1c2db882 100644 --- a/logs/file.go +++ b/logs/file.go @@ -170,7 +170,7 @@ func (w *fileLogWriter) initFd() error { fd := w.fileWriter fInfo, err := fd.Stat() if err != nil { - return fmt.Errorf("get stat err: %s\n", err) + return fmt.Errorf("get stat err: %s", err) } w.maxSizeCurSize = int(fInfo.Size()) w.dailyOpenTime = time.Now() @@ -193,16 +193,14 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { y, m, d := openTime.Add(24 * time.Hour).Date() nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) - select { - case <-tm.C: - w.Lock() - if w.needRotate(0, time.Now().Day()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } + <-tm.C + w.Lock() + if w.needRotate(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } - w.Unlock() } + w.Unlock() } func (w *fileLogWriter) lines() (int, error) { @@ -261,7 +259,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { } // return error if the last file checked still existed if err == nil { - return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename) + return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) } // close fileWriter before rename @@ -270,7 +268,10 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { // Rename the file to its new found name // even if occurs error,we MUST guarantee to restart new logger err = os.Rename(w.Filename, fName) - err = os.Chmod(fName, os.FileMode(440)) + if err != nil { + goto RESTART_LOGGER + } + err = os.Chmod(fName, os.FileMode(0440)) // re-start logger RESTART_LOGGER: @@ -278,13 +279,12 @@ RESTART_LOGGER: go w.deleteOldLog() if startLoggerErr != nil { - return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) + return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) } if err != nil { - return fmt.Errorf("Rotate: %s\n", err) + return fmt.Errorf("Rotate: %s", err) } return nil - } func (w *fileLogWriter) deleteOldLog() { diff --git a/logs/file_test.go b/logs/file_test.go index 69a66d84..f345ff20 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -162,7 +162,27 @@ func TestFileRotate_05(t *testing.T) { testFileDailyRotate(t, fn1, fn2) os.Remove(fn) } - +func TestFileRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} func testFileRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ Daily: true, diff --git a/logs/jianliao.go b/logs/jianliao.go index 16773c93..88ba0f9a 100644 --- a/logs/jianliao.go +++ b/logs/jianliao.go @@ -25,11 +25,7 @@ func newJLWriter() Logger { // Init JLWriter with json config string func (s *JLWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } // WriteMsg write message in smtp writer. @@ -65,12 +61,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { // Flush implementing method. empty. func (s *JLWriter) Flush() { - return } // Destroy implementing method. empty. func (s *JLWriter) Destroy() { - return } func init() { diff --git a/logs/log.go b/logs/log.go index c351c473..0e97a70e 100644 --- a/logs/log.go +++ b/logs/log.go @@ -275,7 +275,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error line = 0 } _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg + msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg } //set level info in front of filename info @@ -492,9 +492,9 @@ func (bl *BeeLogger) flush() { } // beeLogger references the used application logger. -var beeLogger *BeeLogger = NewLogger() +var beeLogger = NewLogger() -// GetLogger returns the default BeeLogger +// GetBeeLogger returns the default BeeLogger func GetBeeLogger() *BeeLogger { return beeLogger } @@ -534,6 +534,7 @@ func Reset() { beeLogger.Reset() } +// Async set the beelogger with Async mode and hold msglen messages func Async(msgLen ...int64) *BeeLogger { return beeLogger.Async(msgLen...) } @@ -561,11 +562,7 @@ func SetLogFuncCallDepth(d int) { // SetLogger sets a new logger. func SetLogger(adapter string, config ...string) error { - err := beeLogger.SetLogger(adapter, config...) - if err != nil { - return err - } - return nil + return beeLogger.SetLogger(adapter, config...) } // Emergency logs a message at emergency level. diff --git a/logs/logger.go b/logs/logger.go index e0abfdc4..b5d7255f 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -139,6 +139,11 @@ var ( reset = string([]byte{27, 91, 48, 109}) ) +// ColorByStatus return color by http code +// 2xx return Green +// 3xx return White +// 4xx return Yellow +// 5xx return Red func ColorByStatus(cond bool, code int) string { switch { case code >= 200 && code < 300: @@ -152,6 +157,14 @@ func ColorByStatus(cond bool, code int) string { } } +// ColorByMethod return color by http code +// GET return Blue +// POST return Cyan +// PUT return Yellow +// DELETE return Red +// PATCH return Green +// HEAD return Magenta +// OPTIONS return WHITE func ColorByMethod(cond bool, method string) string { switch method { case "GET": @@ -173,10 +186,10 @@ func ColorByMethod(cond bool, method string) string { } } -// Guard Mutex to guarantee atomicity of W32Debug(string) function +// Guard Mutex to guarantee atomic of W32Debug(string) function var mu sync.Mutex -// Helper method to output colored logs in Windows terminals +// W32Debug Helper method to output colored logs in Windows terminals func W32Debug(msg string) { mu.Lock() defer mu.Unlock() diff --git a/logs/slack.go b/logs/slack.go index 90f009cb..1cd2e5ae 100644 --- a/logs/slack.go +++ b/logs/slack.go @@ -21,11 +21,7 @@ func newSLACKWriter() Logger { // Init SLACKWriter with json config string func (s *SLACKWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } // WriteMsg write message in smtp writer. @@ -53,12 +49,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { // Flush implementing method. empty. func (s *SLACKWriter) Flush() { - return } // Destroy implementing method. empty. func (s *SLACKWriter) Destroy() { - return } func init() { diff --git a/logs/smtp.go b/logs/smtp.go index 834130ef..6208d7b8 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -52,11 +52,7 @@ func newSMTPWriter() Logger { // "level":LevelError // } func (s *SMTPWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -106,7 +102,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd if err != nil { return err } - _, err = w.Write([]byte(msgContent)) + _, err = w.Write(msgContent) if err != nil { return err } @@ -116,12 +112,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return err } - err = client.Quit() - if err != nil { - return err - } - - return nil + return client.Quit() } // WriteMsg write message in smtp writer. @@ -147,12 +138,10 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { // Flush implementing method. empty. func (s *SMTPWriter) Flush() { - return } // Destroy implementing method. empty. func (s *SMTPWriter) Destroy() { - return } func init() { diff --git a/namespace.go b/namespace.go index cfde0111..72f22a72 100644 --- a/namespace.go +++ b/namespace.go @@ -267,13 +267,12 @@ func addPrefix(t *Tree, prefix string) { addPrefix(t.wildcard, prefix) } for _, l := range t.leaves { - if c, ok := l.runObject.(*controllerInfo); ok { + if c, ok := l.runObject.(*ControllerInfo); ok { if !strings.HasPrefix(c.pattern, prefix) { c.pattern = prefix + c.pattern } } } - } // NSCond is Namespace Condition @@ -284,16 +283,16 @@ func NSCond(cond namespaceCond) LinkNamespace { } // NSBefore Namespace BeforeRouter filter -func NSBefore(filiterList ...FilterFunc) LinkNamespace { +func NSBefore(filterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { - ns.Filter("before", filiterList...) + ns.Filter("before", filterList...) } } // NSAfter add Namespace FinishRouter filter -func NSAfter(filiterList ...FilterFunc) LinkNamespace { +func NSAfter(filterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { - ns.Filter("after", filiterList...) + ns.Filter("after", filterList...) } } diff --git a/namespace_test.go b/namespace_test.go index fc02b5fb..b3f20dff 100644 --- a/namespace_test.go +++ b/namespace_test.go @@ -139,10 +139,7 @@ func TestNamespaceCond(t *testing.T) { ns := NewNamespace("/v2") ns.Cond(func(ctx *context.Context) bool { - if ctx.Input.Domain() == "beego.me" { - return true - } - return false + return ctx.Input.Domain() == "beego.me" }). AutoRouter(&TestController{}) AddNamespace(ns) diff --git a/orm/cmd.go b/orm/cmd.go index 3638a75c..0ff4dc40 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -150,7 +150,7 @@ func (d *commandSyncDb) Run() error { } for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; ok == false { + if _, ok := columns[fi.column]; !ok { fields = append(fields, fi) } } @@ -175,7 +175,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false { + if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 8119b70b..de47cb02 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -89,7 +89,7 @@ checkColumn: col = T["float64"] case TypeDecimalField: s := T["float64-decimal"] - if strings.Index(s, "%d") == -1 { + if !strings.Contains(s, "%d") { col = s } else { col = fmt.Sprintf(s, fi.digits, fi.decimals) @@ -120,7 +120,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) - if fi.null == false { + if !fi.null { typ += " " + "NOT NULL" } @@ -172,7 +172,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex } else { column += col - if fi.null == false { + if !fi.null { column += " " + "NOT NULL" } @@ -192,7 +192,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex } } - if strings.Index(column, "%COL%") != -1 { + if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } diff --git a/orm/db.go b/orm/db.go index bca6071d..12f0f54d 100644 --- a/orm/db.go +++ b/orm/db.go @@ -48,7 +48,7 @@ var ( "lte": true, "eq": true, "nq": true, - "ne": true, + "ne": true, "startswith": true, "endswith": true, "istartswith": true, @@ -87,7 +87,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } else { panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) } - if fi.dbcol == false || fi.auto && skipAuto { + if !fi.dbcol || fi.auto && skipAuto { continue } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) @@ -224,7 +224,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val value = nil } } - if fi.null == false && value == nil { + if !fi.null && value == nil { return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) } } @@ -271,7 +271,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, dbcols := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.fields.dbcols)) for _, fi := range mi.fields.fieldsDB { - if fi.auto == false { + if !fi.auto { dbcols = append(dbcols, fi.column) marks = append(marks, "?") } @@ -326,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } else { // default use pk value as where condtion. pkColumn, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return ErrMissPK } whereCols = []string{pkColumn} @@ -507,10 +507,9 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a case DRPostgres: if len(args) == 0 { return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) - } else { - args0 = strings.ToLower(args[0]) - iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) } + args0 = strings.ToLower(args[0]) + iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) default: return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) } @@ -592,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a row := q.QueryRow(query, values...) var id int64 err = row.Scan(&id) - if err.Error() == `pq: syntax error at or near "ON"` { + if err != nil && err.Error() == `pq: syntax error at or near "ON"` { err = fmt.Errorf("postgres version must 9.5 or higher") } return id, err @@ -601,7 +600,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a // execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return 0, ErrMissPK } @@ -654,7 +653,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. } else { // default use pk value as where condtion. pkColumn, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return 0, ErrMissPK } whereCols = []string{pkColumn} @@ -699,7 +698,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false { + if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { panic(fmt.Errorf("wrong field/column name `%s`", col)) } else { columns = append(columns, fi.column) @@ -834,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con if err := rs.Scan(&ref); err != nil { return 0, err } - args = append(args, reflect.ValueOf(ref).Interface()) + pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + if err != nil { + return 0, err + } + args = append(args, pkValue) cnt++ } @@ -929,7 +932,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if hasRel { for _, fi := range mi.fields.fieldsDB { if fi.fieldType&IsRelField > 0 { - if maps[fi.column] == false { + if !maps[fi.column] { tCols = append(tCols, fi.column) } } @@ -987,7 +990,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var cnt int64 for rs.Next() { - if one && cnt == 0 || one == false { + if one && cnt == 0 || !one { if err := rs.Scan(refs...); err != nil { return 0, err } @@ -1067,7 +1070,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cnt++ } - if one == false { + if !one { if cnt > 0 { ind.Set(slice) } else { @@ -1110,7 +1113,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition // generate sql with replacing operator string placeholders and replaced values. func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { - sql := "" + var sql string params := getFlatParams(fi, args, tz) if len(params) == 0 { @@ -1357,7 +1360,7 @@ end: func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { fieldType := fi.fieldType - isNative := fi.isFielder == false + isNative := !fi.isFielder setValue: switch { @@ -1533,7 +1536,7 @@ setValue: } } - if isNative == false { + if !isNative { fd := field.Addr().Interface().(Fielder) err := fd.SetRaw(value) if err != nil { @@ -1594,7 +1597,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond infos = make([]*fieldInfo, 0, len(exprs)) for _, ex := range exprs { index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) @@ -1733,7 +1736,7 @@ func (d *dbBase) TableQuote() string { return "`" } -// replace value placeholer in parametered sql string. +// replace value placeholder in parametered sql string. func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } diff --git a/orm/db_alias.go b/orm/db_alias.go index c95d49c9..c7089239 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -60,6 +60,8 @@ var ( "sqlite3": DRSqlite, "tidb": DRTiDB, "oracle": DROracle, + "oci8": DROracle, // github.com/mattn/go-oci8 + "ora": DROracle, //https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -186,7 +188,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if dataBaseCache.add(aliasName, al) == false { + if !dataBaseCache.add(aliasName, al) { return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } @@ -244,11 +246,11 @@ end: // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. func RegisterDriver(driverName string, typ DriverType) error { - if t, ok := drivers[driverName]; ok == false { + if t, ok := drivers[driverName]; !ok { drivers[driverName] = typ } else { if t != typ { - return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName) + return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) } } return nil @@ -259,7 +261,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz } else { - return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName) + return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) } return nil } @@ -294,5 +296,5 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { if ok { return al.DB, nil } - return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name) + return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) } diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 1016de2b..51185563 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -103,8 +103,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If no will insert // Add "`" for mysql sql building func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - - iouStr := "" + var iouStr string argsMap := map[string]string{} iouStr = "ON DUPLICATE KEY UPDATE" diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index a3cb69a7..a43a5594 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -134,7 +134,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool defer rows.Close() for rows.Next() { var tmp, index sql.NullString - rows.Scan(&tmp, &index, &tmp) + rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) if name == index.String { return true } diff --git a/orm/db_tables.go b/orm/db_tables.go index e4c74ace..42be5550 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -63,7 +63,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) // add table info to collection. func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) - if _, ok := t.tablesM[name]; ok == false { + if _, ok := t.tablesM[name]; !ok { i := len(t.tables) + 1 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} t.tablesM[name] = jt @@ -261,7 +261,7 @@ loopFor: fiN, okN = mmi.fields.GetByAny(exprs[i+1]) } - if isRel && (fi.mi.isThrough == false || num != i) { + if isRel && (!fi.mi.isThrough || num != i) { if fi.null || t.skipEnd { inner = false } @@ -364,7 +364,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe } index, _, fi, suc := t.parseExprs(mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) } @@ -383,7 +383,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe } } - if sub == false && where != "" { + if !sub && where != "" { where = "WHERE " + where } @@ -403,7 +403,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { exprs := strings.Split(group, ExprSep) index, _, fi, suc := t.parseExprs(t.mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } @@ -432,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { exprs := strings.Split(order, ExprSep) index, _, fi, suc := t.parseExprs(t.mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } diff --git a/orm/models_boot.go b/orm/models_boot.go index 4ba5affd..5327f754 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { } if mi.fields.pk == nil { - fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name) + fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) os.Exit(2) } @@ -128,7 +128,7 @@ func bootStrap() { if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { pn := fi.relThrough[:i] rmi, ok := modelCache.getByFullName(fi.relThrough) - if ok == false || pn != rmi.pkg { + if !ok || pn != rmi.pkg { err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) goto end } @@ -171,7 +171,7 @@ func bootStrap() { break } } - if inModel == false { + if !inModel { rmi := fi.relModelInfo ffi := new(fieldInfo) ffi.name = mi.name @@ -185,7 +185,7 @@ func bootStrap() { } else { ffi.fieldType = RelReverseMany } - if rmi.fields.Add(ffi) == false { + if !rmi.fields.Add(ffi) { added := false for cnt := 0; cnt < 5; cnt++ { ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) @@ -195,7 +195,7 @@ func bootStrap() { break } } - if added == false { + if !added { panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) } } @@ -248,7 +248,7 @@ func bootStrap() { break mForA } } - if found == false { + if !found { err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) goto end } @@ -267,7 +267,7 @@ func bootStrap() { break mForB } } - if found == false { + if !found { mForC: for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || @@ -287,7 +287,7 @@ func bootStrap() { } } } - if found == false { + if !found { err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) goto end } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 4b3d3e27..bbb7d71f 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -47,7 +47,7 @@ func (f *fields) Add(fi *fieldInfo) (added bool) { } else { return } - if _, ok := f.fieldsByType[fi.fieldType]; ok == false { + if _, ok := f.fieldsByType[fi.fieldType]; !ok { f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) } f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) @@ -334,12 +334,12 @@ checkType: switch onDelete { case odCascade, odDoNothing: case odSetDefault: - if initial.Exist() == false { + if !initial.Exist() { err = errors.New("on_delete: set_default need set field a default value") goto end } case odSetNULL: - if fi.null == false { + if !fi.null { err = errors.New("on_delete: set_null need set field null") goto end } diff --git a/orm/models_info_m.go b/orm/models_info_m.go index d6ba1dca..4a3a37f9 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -78,7 +78,7 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) fi.fieldIndex = append(index, i) fi.mi = mi fi.inModel = true - if mi.fields.Add(fi) == false { + if !mi.fields.Add(fi) { err = fmt.Errorf("duplicate column name: %s", fi.column) break } diff --git a/orm/orm.go b/orm/orm.go index d9d1cd77..fcf82590 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect if mi, ok := modelCache.getByFullName(name); ok { return mi, ind } - panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) } // get field info from model info by given field name @@ -122,21 +122,13 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { // read data to model func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) - if err != nil { - return err - } - return nil + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) - if err != nil { - return err - } - return nil + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -238,15 +230,11 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // cols set the columns those want to update. func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - return num, nil + return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) } // delete model in database -// cols shows the delete conditions values read from. deafult is pk +// cols shows the delete conditions values read from. default is pk func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) @@ -361,7 +349,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) - if exist == false { + if !exist { panic(ErrMissPK) } @@ -432,7 +420,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - name := "" + var name string if table, ok := ptrStructOrTableName.(string); ok { name = snakeString(table) if mi, ok := modelCache.get(name); ok { @@ -489,7 +477,7 @@ func (o *orm) Begin() error { // commit transaction func (o *orm) Commit() error { - if o.isTx == false { + if !o.isTx { return ErrTxDone } err := o.db.(txEnder).Commit() @@ -504,7 +492,7 @@ func (o *orm) Commit() error { // rollback transaction func (o *orm) Rollback() error { - if o.isTx == false { + if !o.isTx { return ErrTxDone } err := o.db.(txEnder).Rollback() diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index b220bda6..6a270a0d 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -72,7 +72,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } _, v1, exist := getExistPk(o.mi, o.ind) - if exist == false { + if !exist { panic(ErrMissPK) } @@ -87,7 +87,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { v2 = ind.Interface() } else { _, v2, exist = getExistPk(fi.relModelInfo, ind) - if exist == false { + if !exist { panic(ErrMissPK) } } @@ -104,11 +104,7 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) - nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() - if err != nil { - return nums, err - } - return nums, nil + return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() } // check model is existed in relationship of origin model diff --git a/orm/orm_raw.go b/orm/orm_raw.go index a968b1a1..c8e741ea 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { } } } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = snakeString(fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } } } + + // init call the recursive function + recursiveSetField(ind) } if eTyps[0].Kind() == reflect.Ptr { @@ -671,7 +685,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in ind *reflect.Value ) - typ := 0 + var typ int switch container.(type) { case *Params: typ = 1 diff --git a/orm/orm_test.go b/orm/orm_test.go index 8738952b..f1f2d85e 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -93,14 +93,14 @@ wrongArg: } func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); ok == false { + if ok, err := ValuesCompare(true, a, args...); !ok { return err } return nil } func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); ok == false { + if ok, err := ValuesCompare(false, a, args...); !ok { return err } return nil @@ -135,7 +135,7 @@ func getCaller(skip int) string { if i := strings.LastIndex(funName, "."); i > -1 { funName = funName[i+1:] } - return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n")) + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } func throwFail(t *testing.T, err error, args ...interface{}) { @@ -1014,6 +1014,8 @@ func TestAll(t *testing.T) { var users3 []*User qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) throwFailNow(t, AssertIs(users3 == nil, false)) } @@ -1138,6 +1140,7 @@ func TestRelatedSel(t *testing.T) { } err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(user.Profile, nil)) @@ -1246,20 +1249,24 @@ func TestLoadRelated(t *testing.T) { num, err = dORM.LoadRelated(&user, "Posts", true) throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) num, err = dORM.LoadRelated(&user, "Posts", true, 1) throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) @@ -1654,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(pid, nil)) } +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + func TestQueryRows(t *testing.T) { Q := dDbBaser.TableQuote() @@ -1724,6 +1738,19 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(usernames[1], "astaxie")) throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(usernames[2], "nobody")) + + //test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + } func TestRawValues(t *testing.T) { @@ -1976,6 +2003,7 @@ func TestReadOrCreate(t *testing.T) { created, pk, err := dORM.ReadOrCreate(u, "UserName") throwFail(t, err) throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) throwFail(t, AssertIs(u.UserName, "Kyle")) throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) throwFail(t, AssertIs(u.Password, "other_pass")) @@ -2130,13 +2158,13 @@ func TestUintPk(t *testing.T) { Name: name, } - created, pk, err := dORM.ReadOrCreate(u, "ID") + created, _, err := dORM.ReadOrCreate(u, "ID") throwFail(t, err) throwFail(t, AssertIs(created, true)) throwFail(t, AssertIs(u.Name, name)) nu := &UintPk{ID: 8} - created, pk, err = dORM.ReadOrCreate(nu, "ID") + created, pk, err := dORM.ReadOrCreate(nu, "ID") throwFail(t, err) throwFail(t, AssertIs(created, false)) throwFail(t, AssertIs(nu.ID, u.ID)) diff --git a/orm/utils.go b/orm/utils.go index 6aac8e5d..669d4734 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -92,11 +92,11 @@ func (f StrTo) Int64() (int64, error) { i := new(big.Int) ni, ok := i.SetString(f.String(), 10) // octal if !ok { - return int64(v), err + return v, err } return ni.Int64(), nil } - return int64(v), err + return v, err } // Uint string to uint @@ -130,11 +130,11 @@ func (f StrTo) Uint64() (uint64, error) { i := new(big.Int) ni, ok := i.SetString(f.String(), 10) if !ok { - return uint64(v), err + return v, err } return ni.Uint64(), nil } - return uint64(v), err + return v, err } // String string to string @@ -225,7 +225,7 @@ func camelString(s string) string { if d == '_' { flag = true continue - } else if flag == true { + } else if flag { if d >= 'a' && d <= 'z' { d = d - 32 } diff --git a/parser.go b/parser.go index d40ee3ce..a9cfd894 100644 --- a/parser.go +++ b/parser.go @@ -24,9 +24,13 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" "sort" + "strconv" "strings" + "unicode" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers import ( "github.com/astaxie/beego" + "github.com/astaxie/beego/context/param" ) func init() { @@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { if specDecl.Recv != nil { exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser if ok { - parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) } } } @@ -93,44 +98,169 @@ func parserPkg(pkgRealpath, pkgpath string) error { return nil } -func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error { - if comments != nil && comments.List != nil { - for _, c := range comments.List { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - elements := strings.TrimLeft(t, "@router ") - e1 := strings.SplitN(elements, " ", 2) - if len(e1) < 1 { - return errors.New("you should has router information") - } - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = funcName - cc.Router = e1[0] - if len(e1) == 2 && e1[1] != "" { - e1 = strings.SplitN(e1[1], " ", 2) - if len(e1) >= 1 { - cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",") - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - if len(e1) == 2 && e1[1] != "" { - keyval := strings.Split(strings.Trim(e1[1], "[]"), " ") - for _, kv := range keyval { - kk := strings.Split(kv, ":") - cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]}) - } - } - genInfoList[key] = append(genInfoList[key], cc) - } +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComment, err := parseComment(f.Doc.List) + if err != nil { + return err } + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + genInfoList[key] = append(genInfoList[key], cc) + } + } return nil } +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + methodParam := buildMethodParam(fparam, pc) + result = append(result, methodParam) + } + return result +} + +func buildMethodParam(fparam *ast.Field, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + name := fparam.Names[0].Name + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) +} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { + pc = &parsedComment{} + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + } else { + return nil, errors.New("Router information is missing") + } + } else if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.SplitN(pv[0], "=>", 2) + p.name = names[0] + funcParamName := p.name + if len(names) > 1 { + funcParamName = names[1] + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + if pc.params == nil { + pc.params = map[string]parsedParam{} + } + pc.params[funcParamName] = p + } + } + return +} + +// direct copy from bee\g_docs.go +// analisys params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range str { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + func genRouterCode(pkgRealpath string) { os.Mkdir(getRouterDir(pkgRealpath), 0755) logs.Info("generate router from comments") @@ -163,12 +293,24 @@ func genRouterCode(pkgRealpath string) { } params = strings.TrimRight(params, ",") + "}" } + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" globalinfo = globalinfo + ` beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.ControllerComments{ Method: "` + strings.TrimSpace(c.Method) + `", ` + "Router: `" + c.Router + "`" + `, AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, Params: ` + params + `}) ` } diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go new file mode 100644 index 00000000..709a613a --- /dev/null +++ b/plugins/authz/authz.go @@ -0,0 +1,86 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/hsluoyz/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/hsluoyz/casbin" + "net/http" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuthorizer{enforcer: e} + + if !a.CheckPermission(ctx.Request) { + a.RequirePermission(ctx.ResponseWriter) + } + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer struct { + enforcer *casbin.Enforcer +} + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + username, _, _ := r.BasicAuth() + return username +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + user := a.GetUserName(r) + method := r.Method + path := r.URL.Path + return a.enforcer.Enforce(user, path, method) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + w.WriteHeader(403) + w.Write([]byte("403 Forbidden\n")) +} diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go new file mode 100644 index 00000000..4003582c --- /dev/null +++ b/plugins/authz/authz_test.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/plugins/auth" + "github.com/hsluoyz/casbin" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/policy.go b/policy.go index 2b91fdcc..ab23f927 100644 --- a/policy.go +++ b/policy.go @@ -23,7 +23,7 @@ import ( // PolicyFunc defines a policy function which is invoked before the controller handler is executed. type PolicyFunc func(*context.Context) -// FindRouter Find Router info for URL +// FindPolicy Find Router info for URL func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { var urlPath = cont.Input.URL() if !BConfig.RouterCaseSensitive { @@ -71,7 +71,7 @@ func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc } } -// Register new policy in beego +// Policy Register new policy in beego func Policy(pattern, method string, policy ...PolicyFunc) { BeeApp.Handlers.addToPolicy(method, pattern, policy...) } diff --git a/router.go b/router.go index 9f573f26..72476ae8 100644 --- a/router.go +++ b/router.go @@ -17,7 +17,6 @@ package beego import ( "fmt" "net/http" - "os" "path" "path/filepath" "reflect" @@ -28,6 +27,7 @@ import ( "time" beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" @@ -109,13 +109,15 @@ func ExceptMethodAppend(action string) { exceptMethod = append(exceptMethod, action) } -type controllerInfo struct { +// ControllerInfo holds information about the controller. +type ControllerInfo struct { pattern string controllerType reflect.Type methods map[string]string handler http.Handler runFunction FilterFunc routerType int + methodParams []*param.MethodParam } // ControllerRegister containers registered router rules, controller handlers and filters. @@ -151,6 +153,10 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + p.addWithMethodParams(pattern, c, nil, mappingMethods...) +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) @@ -176,11 +182,12 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM } } - route := &controllerInfo{} + route := &ControllerInfo{} route.pattern = pattern route.methods = methods route.routerType = routerTypeBeego route.controllerType = t + route.methodParams = methodParams if len(methods) == 0 { for _, m := range HTTPMETHOD { p.addToRouter(m, pattern, route) @@ -198,7 +205,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM } } -func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) { +func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { if !BConfig.RouterCaseSensitive { pattern = strings.ToLower(pattern) } @@ -219,13 +226,11 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { for _, c := range cList { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() - gopath := os.Getenv("GOPATH") - if gopath == "" { + wgopath := utils.GetGOPATHs() + if len(wgopath) == 0 { panic("you are in dev mode. So please set gopath") } pkgpath := "" - - wgopath := filepath.SplitList(gopath) for _, wg := range wgopath { wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) if utils.FileExists(wg) { @@ -247,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { key := t.PkgPath() + ":" + t.Name() if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { - p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) } } } @@ -335,7 +340,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { if _, ok := HTTPMETHOD[method]; method != "*" && !ok { panic("not support http method: " + method) } - route := &controllerInfo{} + route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeRESTFul route.runFunction = f @@ -361,7 +366,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { // Handler add user defined Handler func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &controllerInfo{} + route := &ControllerInfo{} route.pattern = pattern route.routerType = routerTypeHandler route.handler = h @@ -396,7 +401,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) controllerName := strings.TrimSuffix(ct.Name(), "Controller") for i := 0; i < rt.NumMethod(); i++ { if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &controllerInfo{} + route := &ControllerInfo{} route.routerType = routerTypeBeego route.methods = map[string]string{"*": rt.Method(i).Name} route.controllerType = ct @@ -502,7 +507,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin } } for _, l := range t.leaves { - if c, ok := l.runObject.(*controllerInfo); ok { + if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { find := false @@ -626,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { startTime := time.Now() var ( - runRouter reflect.Type - findRouter bool - runMethod string - routerInfo *controllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool ) context := p.pool.Get().(*beecontext.Context) context.Reset(rw, r) @@ -670,7 +676,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) goto Admin } - if r.Method != "GET" && r.Method != "HEAD" { + if r.Method != http.MethodGet && r.Method != http.MethodHead { if BConfig.CopyRequestBody && !context.Input.IsUpload() { context.Input.CopyBody(BConfig.MaxMemory) } @@ -742,12 +748,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) routerInfo.handler.ServeHTTP(rw, r) } else { runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams method := r.Method - if r.Method == "POST" && context.Input.Query("_method") == "PUT" { - method = "PUT" + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { + method = http.MethodPut } - if r.Method == "POST" && context.Input.Query("_method") == "DELETE" { - method = "DELETE" + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + method = http.MethodDelete } if m, ok := routerInfo.methods[method]; ok { runMethod = m @@ -777,8 +784,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf if BConfig.WebConfig.EnableXSRF { execController.XSRFToken() - if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" || - (r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) { + if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || + (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { execController.CheckXSRFCookie() } } @@ -788,25 +795,30 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !context.ResponseWriter.Started { //exec main logic switch runMethod { - case "GET": + case http.MethodGet: execController.Get() - case "POST": + case http.MethodPost: execController.Post() - case "DELETE": + case http.MethodDelete: execController.Delete() - case "PUT": + case http.MethodPut: execController.Put() - case "HEAD": + case http.MethodHead: execController.Head() - case "PATCH": + case http.MethodPatch: execController.Patch() - case "OPTIONS": + case http.MethodOptions: execController.Options() default: if !execController.HandlerFunc(runMethod) { - var in []reflect.Value method := vc.MethodByName(runMethod) - method.Call(in) + in := param.ConvertParams(methodParams, method.Type(), context) + out := method.Call(in) + + //For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(context, execController, out) + } } } @@ -886,8 +898,22 @@ Admin: } } +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + //looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + // FindRouter Find Router info for URL -func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *controllerInfo, isFind bool) { +func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { var urlPath = context.Input.URL() if !BConfig.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) @@ -895,7 +921,7 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo httpMethod := context.Input.Method() if t, ok := p.routers[httpMethod]; ok { runObject := t.Match(urlPath, context) - if r, ok := runObject.(*controllerInfo); ok { + if r, ok := runObject.(*ControllerInfo); ok { return r, true } } diff --git a/router_test.go b/router_test.go index 936fd5e8..720b4ca8 100644 --- a/router_test.go +++ b/router_test.go @@ -502,10 +502,10 @@ func TestFilterBeforeRouter(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "BeforeRouter1") == false { + if !strings.Contains(rw.Body.String(), "BeforeRouter1") { t.Errorf(testName + " BeforeRouter did not run") } - if strings.Contains(rw.Body.String(), "hello") == true { + if strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " BeforeRouter did not return properly") } } @@ -525,13 +525,13 @@ func TestFilterBeforeExec(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "BeforeExec1") == false { + if !strings.Contains(rw.Body.String(), "BeforeExec1") { t.Errorf(testName + " BeforeExec did not run") } - if strings.Contains(rw.Body.String(), "hello") == true { + if strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " BeforeExec did not return properly") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } } @@ -552,16 +552,16 @@ func TestFilterAfterExec(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "AfterExec1") == false { + if !strings.Contains(rw.Body.String(), "AfterExec1") { t.Errorf(testName + " AfterExec did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } - if strings.Contains(rw.Body.String(), "BeforeExec") == true { + if strings.Contains(rw.Body.String(), "BeforeExec") { t.Errorf(testName + " BeforeExec ran in error") } } @@ -583,19 +583,19 @@ func TestFilterFinishRouter(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == true { + if strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "AfterExec1") == true { + if strings.Contains(rw.Body.String(), "AfterExec1") { t.Errorf(testName + " AfterExec ran in error") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } - if strings.Contains(rw.Body.String(), "BeforeExec") == true { + if strings.Contains(rw.Body.String(), "BeforeExec") { t.Errorf(testName + " BeforeExec ran in error") } } @@ -615,14 +615,14 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter1 did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } // not expected in body - if strings.Contains(rw.Body.String(), "FinishRouter2") == true { + if strings.Contains(rw.Body.String(), "FinishRouter2") { t.Errorf(testName + " FinishRouter2 did run") } } @@ -642,44 +642,52 @@ func TestFilterFinishRouterMulti(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter1 did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "FinishRouter2") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter2") { t.Errorf(testName + " FinishRouter2 did not run properly") } } func beegoFilterNoOutput(ctx *context.Context) { - return } + func beegoBeforeRouter1(ctx *context.Context) { ctx.WriteString("|BeforeRouter1") } + func beegoBeforeRouter2(ctx *context.Context) { ctx.WriteString("|BeforeRouter2") } + func beegoBeforeExec1(ctx *context.Context) { ctx.WriteString("|BeforeExec1") } + func beegoBeforeExec2(ctx *context.Context) { ctx.WriteString("|BeforeExec2") } + func beegoAfterExec1(ctx *context.Context) { ctx.WriteString("|AfterExec1") } + func beegoAfterExec2(ctx *context.Context) { ctx.WriteString("|AfterExec2") } + func beegoFinishRouter1(ctx *context.Context) { ctx.WriteString("|FinishRouter1") } + func beegoFinishRouter2(ctx *context.Context) { ctx.WriteString("|FinishRouter2") } + func beegoResetParams(ctx *context.Context) { ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) } diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go index d5be11d0..707d042c 100644 --- a/session/couchbase/sess_couchbase.go +++ b/session/couchbase/sess_couchbase.go @@ -155,11 +155,16 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (cp *Provider) SessionRead(sid string) (session.Store, error) { cp.b = cp.getBucket() - var doc []byte + var ( + kv map[interface{}]interface{} + err error + doc []byte + ) - err := cp.b.Get(sid, &doc) - var kv map[interface{}]interface{} - if doc == nil { + err = cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } else if doc == nil { kv = make(map[interface{}]interface{}) } else { kv, err = session.DecodeGob(doc) @@ -230,7 +235,6 @@ func (cp *Provider) SessionDestroy(sid string) error { // SessionGC Recycle func (cp *Provider) SessionGC() { - return } // SessionAll return all active session diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index 68f37b08..77685d1e 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -12,8 +12,10 @@ import ( "github.com/siddontang/ledisdb/ledis" ) -var ledispder = &Provider{} -var c *ledis.DB +var ( + ledispder = &Provider{} + c *ledis.DB +) // SessionStore ledis session store type SessionStore struct { @@ -97,27 +99,33 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { } cfg := new(config.Config) cfg.DataDir = lp.savePath - nowLedis, err := ledis.Open(cfg) - c, err = nowLedis.Select(lp.db) + + var ledisInstance *ledis.Ledis + ledisInstance, err = ledis.Open(cfg) if err != nil { - println(err) - return nil + return err } - return nil + c, err = ledisInstance.Select(lp.db) + return err } // SessionRead read ledis session by sid func (lp *Provider) SessionRead(sid string) (session.Store, error) { - kvs, err := c.Get([]byte(sid)) - var kv map[interface{}]interface{} + var ( + kv map[interface{}]interface{} + err error + ) + + kvs, _ := c.Get([]byte(sid)) + if len(kvs) == 0 { kv = make(map[interface{}]interface{}) } else { - kv, err = session.DecodeGob(kvs) - if err != nil { + if kv, err = session.DecodeGob(kvs); err != nil { return nil, err } } + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} return ls, nil } @@ -125,10 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist check ledis session exist by sid func (lp *Provider) SessionExist(sid string) bool { count, _ := c.Exists([]byte(sid)) - if count == 0 { - return false - } - return true + return !(count == 0) } // SessionRegenerate generate new sid for ledis session @@ -145,18 +150,7 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Set([]byte(sid), data) c.Expire([]byte(sid), lp.maxlifetime) } - kvs, err := c.Get([]byte(sid)) - var kv map[interface{}]interface{} - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} - return ls, nil + return lp.SessionRead(sid) } // SessionDestroy delete ledis session by id @@ -167,7 +161,6 @@ func (lp *Provider) SessionDestroy(sid string) error { // SessionGC Impelment method, no used. func (lp *Provider) SessionGC() { - return } // SessionAll return all active session diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index f1069bc9..755979c4 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -205,11 +205,7 @@ func (rp *MemProvider) SessionDestroy(sid string) error { } } - err := client.Delete(sid) - if err != nil { - return err - } - return nil + return client.Delete(sid) } func (rp *MemProvider) connectInit() error { @@ -219,7 +215,6 @@ func (rp *MemProvider) connectInit() error { // SessionGC Impelment method, no used. func (rp *MemProvider) SessionGC() { - return } // SessionAll return all activeSession diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 7683ee1f..4c9251e7 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -143,7 +143,6 @@ func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { // SessionRead get mysql session by sid func (mp *Provider) SessionRead(sid string) (session.Store, error) { c := mp.connectInit() - defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) @@ -171,16 +170,12 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - return false - } - return true + return !(err == sql.ErrNoRows) } // SessionRegenerate generate new sid for mysql session func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := mp.connectInit() - defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte err := row.Scan(&sessiondata) @@ -214,7 +209,6 @@ func (mp *Provider) SessionGC() { c := mp.connectInit() c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() - return } // SessionAll count values in mysql session diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index 73f9c13a..ffc27def 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -184,11 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - - if err == sql.ErrNoRows { - return false - } - return true + return !(err == sql.ErrNoRows) } // SessionRegenerate generate new sid for postgresql session @@ -228,7 +224,6 @@ func (mp *Provider) SessionGC() { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() - return } // SessionAll count values in postgresql session diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index c46fa7cd..d0424515 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -128,7 +128,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize <= 0 { + if err != nil || poolsize < 0 { rp.poolsize = MaxPoolSize } else { rp.poolsize = poolsize @@ -155,7 +155,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { return nil, err } if rp.password != "" { - if _, err := c.Do("AUTH", rp.password); err != nil { + if _, err = c.Do("AUTH", rp.password); err != nil { c.Close() return nil, err } @@ -176,13 +176,16 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { c := rp.poollist.Get() defer c.Close() - kvs, err := redis.String(c.Do("GET", sid)) var kv map[interface{}]interface{} + + kvs, err := redis.String(c.Do("GET", sid)) + if err != nil && err != redis.ErrNil { + return nil, err + } if len(kvs) == 0 { kv = make(map[interface{}]interface{}) } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { return nil, err } } @@ -216,20 +219,7 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Do("RENAME", oldsid, sid) c.Do("EXPIRE", sid, rp.maxlifetime) } - - kvs, err := redis.String(c.Do("GET", sid)) - var kv map[interface{}]interface{} - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(kvs)) - if err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil + return rp.SessionRead(sid) } // SessionDestroy delete redis session by id @@ -243,7 +233,6 @@ func (rp *Provider) SessionDestroy(sid string) error { // SessionGC Impelment method, no used. func (rp *Provider) SessionGC() { - return } // SessionAll return all activeSession diff --git a/session/sess_cookie.go b/session/sess_cookie.go index 3fefa360..145e53c9 100644 --- a/session/sess_cookie.go +++ b/session/sess_cookie.go @@ -74,21 +74,16 @@ func (st *CookieSessionStore) SessionID() string { // SessionRelease Write cookie session to http response cookie func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { - str, err := encodeCookie(cookiepder.block, - cookiepder.config.SecurityKey, - cookiepder.config.SecurityName, - st.values) - if err != nil { - return + encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) + if err == nil { + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(encodedCookie), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure, + MaxAge: cookiepder.config.Maxage} + http.SetCookie(w, cookie) } - cookie := &http.Cookie{Name: cookiepder.config.CookieName, - Value: url.QueryEscape(str), - Path: "/", - HttpOnly: true, - Secure: cookiepder.config.Secure, - MaxAge: cookiepder.config.Maxage} - http.SetCookie(w, cookie) - return } type cookieConfig struct { @@ -166,7 +161,6 @@ func (pder *CookieProvider) SessionDestroy(sid string) error { // SessionGC Implement method, no used. func (pder *CookieProvider) SessionGC() { - return } // SessionAll Implement method, return 0. diff --git a/session/sess_file.go b/session/sess_file.go index 50687c9e..3ca93d55 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -87,9 +87,16 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { var f *os.File if err == nil { f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + if err != nil { + SLogger.Println(err) + return + } } else if os.IsNotExist(err) { f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) - + if err != nil { + SLogger.Println(err) + return + } } else { return } @@ -163,10 +170,7 @@ func (fp *FileProvider) SessionExist(sid string) bool { defer filepder.lock.Unlock() _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - if err == nil { - return true - } - return false + return err == nil } // SessionDestroy Remove all files in this save path diff --git a/session/sess_test.go b/session/sess_test.go index b40865f3..906abec2 100644 --- a/session/sess_test.go +++ b/session/sess_test.go @@ -74,8 +74,7 @@ func TestCookieEncodeDecode(t *testing.T) { if err != nil { t.Fatal("encodeCookie:", err) } - dst := make(map[interface{}]interface{}) - dst, err = decodeCookie(block, hashKey, securityName, str, 3600) + dst, err := decodeCookie(block, hashKey, securityName, str, 3600) if err != nil { t.Fatal("decodeCookie", err) } @@ -115,7 +114,7 @@ func TestParseConfig(t *testing.T) { if cf2.Gclifetime != 3600 { t.Fatal("parseconfig get gclifetime error") } - if cf2.EnableSetCookie != false { + if cf2.EnableSetCookie { t.Fatal("parseconfig get enableSetCookie error") } cconfig := new(cookieConfig) diff --git a/session/session.go b/session/session.go index fb4b2821..cf647521 100644 --- a/session/session.go +++ b/session/session.go @@ -81,6 +81,7 @@ func Register(name string, provide Provider) { provides[name] = provide } +// ManagerConfig define the session config type ManagerConfig struct { CookieName string `json:"cookieName"` EnableSetCookie bool `json:"enableSetCookie,omitempty"` @@ -92,9 +93,9 @@ type ManagerConfig struct { ProviderConfig string `json:"providerConfig"` Domain string `json:"domain"` SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"` - SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"` - EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` } // Manager contains Provider and its configuration. @@ -125,14 +126,14 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { cf.Maxlifetime = cf.Gclifetime } - if cf.EnableSidInHttpHeader { - if cf.SessionNameInHttpHeader == "" { - panic(errors.New("SessionNameInHttpHeader is empty")) + if cf.EnableSidInHTTPHeader { + if cf.SessionNameInHTTPHeader == "" { + panic(errors.New("SessionNameInHTTPHeader is empty")) } - strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader) - if cf.SessionNameInHttpHeader != strMimeHeader { - strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader + strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) + if cf.SessionNameInHTTPHeader != strMimeHeader { + strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader panic(errors.New(strErrMsg)) } } @@ -163,7 +164,7 @@ func (manager *Manager) getSid(r *http.Request) (string, error) { cookie, errs := r.Cookie(manager.config.CookieName) if errs != nil || cookie.Value == "" { var sid string - if manager.config.EnableSidInUrlQuery { + if manager.config.EnableSidInURLQuery { errs := r.ParseForm() if errs != nil { return "", errs @@ -173,8 +174,8 @@ func (manager *Manager) getSid(r *http.Request) (string, error) { } // if not found in Cookie / param, then read it from request headers - if manager.config.EnableSidInHttpHeader && sid == "" { - sids, isFound := r.Header[manager.config.SessionNameInHttpHeader] + if manager.config.EnableSidInHTTPHeader && sid == "" { + sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] if isFound && len(sids) != 0 { return sids[0], nil } @@ -226,9 +227,9 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } r.AddCookie(cookie) - if manager.config.EnableSidInHttpHeader { - r.Header.Set(manager.config.SessionNameInHttpHeader, sid) - w.Header().Set(manager.config.SessionNameInHttpHeader, sid) + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) } return @@ -236,9 +237,9 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se // SessionDestroy Destroy session by its id in http request cookie. func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { - if manager.config.EnableSidInHttpHeader { - r.Header.Del(manager.config.SessionNameInHttpHeader) - w.Header().Del(manager.config.SessionNameInHttpHeader) + if manager.config.EnableSidInHTTPHeader { + r.Header.Del(manager.config.SessionNameInHTTPHeader) + w.Header().Del(manager.config.SessionNameInHTTPHeader) } cookie, err := r.Cookie(manager.config.CookieName) @@ -306,9 +307,9 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque } r.AddCookie(cookie) - if manager.config.EnableSidInHttpHeader { - r.Header.Set(manager.config.SessionNameInHttpHeader, sid) - w.Header().Set(manager.config.SessionNameInHttpHeader, sid) + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) } return @@ -328,7 +329,7 @@ func (manager *Manager) sessionID() (string, error) { b := make([]byte, manager.config.SessionIDLength) n, err := rand.Read(b) if n != len(b) || err != nil { - return "", fmt.Errorf("Could not successfully read from the system CSPRNG.") + return "", fmt.Errorf("Could not successfully read from the system CSPRNG") } return hex.EncodeToString(b), nil } diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go index 4dcf160a..de0c6360 100644 --- a/session/ssdb/sess_ssdb.go +++ b/session/ssdb/sess_ssdb.go @@ -11,44 +11,40 @@ import ( "github.com/ssdb/gossdb/ssdb" ) -var ssdbProvider = &SsdbProvider{} +var ssdbProvider = &Provider{} -type SsdbProvider struct { +// Provider holds ssdb client and configs +type Provider struct { client *ssdb.Client host string port int maxLifetime int64 } -func (p *SsdbProvider) connectInit() error { +func (p *Provider) connectInit() error { var err error if p.host == "" || p.port == 0 { return errors.New("SessionInit First") } p.client, err = ssdb.Connect(p.host, p.port) - if err != nil { - return err - } - return nil + return err } -func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error { - var e error = nil +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { p.maxLifetime = maxLifetime address := strings.Split(savePath, ":") p.host = address[0] - p.port, e = strconv.Atoi(address[1]) - if e != nil { - return e - } - err := p.connectInit() - if err != nil { + + var err error + if p.port, err = strconv.Atoi(address[1]); err != nil { return err } - return nil + return p.connectInit() } -func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -71,7 +67,8 @@ func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { return rs, nil } -func (p *SsdbProvider) SessionExist(sid string) bool { +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { if p.client == nil { if err := p.connectInit(); err != nil { panic(err) @@ -85,9 +82,10 @@ func (p *SsdbProvider) SessionExist(sid string) bool { return false } return true - } -func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { //conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { @@ -119,27 +117,27 @@ func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, err return rs, nil } -func (p *SsdbProvider) SessionDestroy(sid string) error { +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { return err } } _, err := p.client.Del(sid) - if err != nil { - return err - } - return nil + return err } -func (p *SsdbProvider) SessionGC() { - return +// SessionGC not implemented +func (p *Provider) SessionGC() { } -func (p *SsdbProvider) SessionAll() int { +// SessionAll not implemented +func (p *Provider) SessionAll() int { return 0 } +// SessionStore holds the session information which stored in ssdb type SessionStore struct { sid string lock sync.RWMutex @@ -148,12 +146,15 @@ type SessionStore struct { client *ssdb.Client } +// Set the key and value func (s *SessionStore) Set(key, value interface{}) error { s.lock.Lock() defer s.lock.Unlock() s.values[key] = value return nil } + +// Get return the value by the key func (s *SessionStore) Get(key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() @@ -163,30 +164,36 @@ func (s *SessionStore) Get(key interface{}) interface{} { return nil } +// Delete the key in session store func (s *SessionStore) Delete(key interface{}) error { s.lock.Lock() defer s.lock.Unlock() delete(s.values, key) return nil } + +// Flush delete all keys and values func (s *SessionStore) Flush() error { s.lock.Lock() defer s.lock.Unlock() s.values = make(map[interface{}]interface{}) return nil } + +// SessionID return the sessionID func (s *SessionStore) SessionID() string { return s.sid } +// SessionRelease Store the keyvalues into ssdb func (s *SessionStore) SessionRelease(w http.ResponseWriter) { b, err := session.EncodeGob(s.values) if err != nil { return } s.client.Do("setx", s.sid, string(b), s.maxLifetime) - } + func init() { session.Register("ssdb", ssdbProvider) } diff --git a/staticfile.go b/staticfile.go index b7be24f3..bbb2a1fb 100644 --- a/staticfile.go +++ b/staticfile.go @@ -90,8 +90,6 @@ func serverStaticRouter(ctx *context.Context) { } http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch) - return - } type serveContentHolder struct { @@ -109,14 +107,14 @@ var ( func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { mapKey := acceptEncoding + ":" + filePath mapLock.RLock() - mapFile, _ := staticFileMap[mapKey] + mapFile := staticFileMap[mapKey] mapLock.RUnlock() if isOk(mapFile, fi) { return mapFile.encoding != "", mapFile.encoding, mapFile, nil } mapLock.Lock() defer mapLock.Unlock() - if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) { + if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { file, err := os.Open(filePath) if err != nil { return false, "", nil, err diff --git a/swagger/swagger.go b/swagger/swagger.go index e0ac5cf5..035d5a49 100644 --- a/swagger/swagger.go +++ b/swagger/swagger.go @@ -22,19 +22,19 @@ package swagger // Swagger list the resource type Swagger struct { - SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` - Infos Information `json:"info" yaml:"info"` - Host string `json:"host,omitempty" yaml:"host,omitempty"` - BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Paths map[string]*Item `json:"paths" yaml:"paths"` - Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` - SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` - Security map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` + SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` + Infos Information `json:"info" yaml:"info"` + Host string `json:"host,omitempty" yaml:"host,omitempty"` + BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Paths map[string]*Item `json:"paths" yaml:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` + SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` } // Information Provides metadata about the API. The metadata can be used by the clients if needed. @@ -75,16 +75,17 @@ type Item struct { // Operation Describes a single API operation on a path. type Operation struct { - Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` - Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` - Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` - Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` } // Parameter Describes a single operation parameter. @@ -100,7 +101,7 @@ type Parameter struct { Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` } -// A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". // http://swagger.io/specification/#itemsObject type ParameterItems struct { Type string `json:"type,omitempty" yaml:"type,omitempty"` diff --git a/template.go b/template.go index 17c18591..d4859cd7 100644 --- a/template.go +++ b/template.go @@ -31,11 +31,11 @@ import ( ) var ( - beegoTplFuncMap = make(template.FuncMap) + beegoTplFuncMap = make(template.FuncMap) beeViewPathTemplateLocked = false // beeViewPathTemplates caching map and supported template file extensions per view - beeViewPathTemplates = make(map[string]map[string]*template.Template) - templatesLock sync.RWMutex + beeViewPathTemplates = make(map[string]map[string]*template.Template) + templatesLock sync.RWMutex // beeTemplateExt stores the template extension which will build beeTemplateExt = []string{"tpl", "html"} // beeTemplatePreprocessors stores associations of extension -> preprocessor handler @@ -46,7 +46,7 @@ var ( // writing the output to wr. // A template will be executed safely in parallel. func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { - return ExecuteViewPathTemplate(wr,name, BConfig.WebConfig.ViewsPath, data) + return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) } // ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, @@ -57,7 +57,7 @@ func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data in templatesLock.RLock() defer templatesLock.RUnlock() } - if beeTemplates,ok := beeViewPathTemplates[viewPath]; ok { + if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { if t, ok := beeTemplates[name]; ok { var err error if t.Lookup(name) != nil { @@ -72,7 +72,7 @@ func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data in } panic("can't find templatefile in the path:" + viewPath + "/" + name) } - panic("Uknown view path:" + viewPath) + panic("Unknown view path:" + viewPath) } func init() { @@ -160,11 +160,14 @@ func AddTemplateExt(ext string) { beeTemplateExt = append(beeTemplateExt, ext) } -// AddViewPath adds a new path to the supported view paths. +// AddViewPath adds a new path to the supported view paths. //Can later be used by setting a controller ViewPath to this folder -//will panic if called after beego.Run() +//will panic if called after beego.Run() func AddViewPath(viewPath string) error { if beeViewPathTemplateLocked { + if _, exist := beeViewPathTemplates[viewPath]; exist { + return nil //Ignore if viewpath already exists + } panic("Can not add new view paths after beego.Run()") } beeViewPathTemplates[viewPath] = make(map[string]*template.Template) @@ -184,7 +187,7 @@ func BuildTemplate(dir string, files ...string) error { } return errors.New("dir open err") } - beeTemplates,ok := beeViewPathTemplates[dir]; + beeTemplates, ok := beeViewPathTemplates[dir] if !ok { panic("Unknown view path: " + dir) } @@ -214,7 +217,7 @@ func BuildTemplate(dir string, files ...string) error { t, err = getTemplate(self.root, file, v...) } if err != nil { - logs.Trace("parse template err:", file, err) + logs.Error("parse template err:", file, err) } else { beeTemplates[file] = t } @@ -227,9 +230,12 @@ func BuildTemplate(dir string, files ...string) error { func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) { var fileAbsPath string + var rParent string if filepath.HasPrefix(file, "../") { + rParent = filepath.Join(filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) } else { + rParent = file fileAbsPath = filepath.Join(root, file) } if e := utils.FileExists(fileAbsPath); !e { @@ -254,7 +260,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp if !HasTemplateExt(m[1]) { continue } - _, _, err = getTplDeep(root, m[1], file, t) + _, _, err = getTplDeep(root, m[1], rParent, t) if err != nil { return nil, [][]string{}, err } @@ -293,7 +299,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) - } else if subMods1 != nil && len(subMods1) > 0 { + } else if len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } break @@ -301,8 +307,9 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others } //second check define for _, otherFile := range others { + var data []byte fileAbsPath := filepath.Join(root, otherFile) - data, err := ioutil.ReadFile(fileAbsPath) + data, err = ioutil.ReadFile(fileAbsPath) if err != nil { continue } @@ -314,7 +321,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) - } else if subMods1 != nil && len(subMods1) > 0 { + } else if len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } break @@ -358,6 +365,7 @@ func DelStaticPath(url string) *App { return BeeApp } +// AddTemplateEngine add a new templatePreProcessor which support extension func AddTemplateEngine(extension string, fn templatePreProcessor) *App { AddTemplateExt(extension) beeTemplateEngines[extension] = fn diff --git a/template_test.go b/template_test.go index 17690965..2153ef72 100644 --- a/template_test.go +++ b/template_test.go @@ -15,6 +15,7 @@ package beego import ( + "bytes" "os" "path/filepath" "testing" @@ -142,3 +143,116 @@ func TestRelativeTemplate(t *testing.T) { } os.RemoveAll(dir) } + +var add = `{{ template "layout_blog.tpl" . }} +{{ define "css" }} + +{{ end}} + + +{{ define "content" }} +

{{ .Title }}

+

This is SomeVar: {{ .SomeVar }}

+{{ end }} + +{{ define "js" }} + +{{ end}}` + +var layoutBlog = ` + + + Lin Li + + + + + {{ block "css" . }}{{ end }} + + + +
+ {{ block "content" . }}{{ end }} +
+ + + {{ block "js" . }}{{ end }} + +` + +var output = ` + + + Lin Li + + + + + + + + + + +
+ +

Hello

+

This is SomeVar: val

+ +
+ + + + + + + + + + + + +` + +func TestTemplateLayout(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "add.tpl", + "layout_blog.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(add) + } else if k == 1 { + f.WriteString(layoutBlog) + } + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 2 { + t.Fatalf("should be 2 but got %v", len(beeTemplates)) + } + out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + if out.String() != output { + t.Log(out.String()) + t.Fatal("Compare failed") + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} diff --git a/templatefunc.go b/templatefunc.go index 01751717..a104fd24 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -26,6 +26,13 @@ import ( "time" ) +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + // Substr returns the substr from start to length. func Substr(s string, start, length int) string { bt := []rune(s) @@ -46,26 +53,25 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - src := string(html) - re, _ := regexp.Compile("\\<[\\S\\s]+?\\>") - src = re.ReplaceAllStringFunc(src, strings.ToLower) + re, _ := regexp.Compile(`\<[\S\s]+?\>`) + html = re.ReplaceAllStringFunc(html, strings.ToLower) //remove STYLE - re, _ = regexp.Compile("\\") - src = re.ReplaceAllString(src, "") + re, _ = regexp.Compile(`\`) + html = re.ReplaceAllString(html, "") //remove SCRIPT - re, _ = regexp.Compile("\\") - src = re.ReplaceAllString(src, "") + re, _ = regexp.Compile(`\`) + html = re.ReplaceAllString(html, "") - re, _ = regexp.Compile("\\<[\\S\\s]+?\\>") - src = re.ReplaceAllString(src, "\n") + re, _ = regexp.Compile(`\<[\S\s]+?\>`) + html = re.ReplaceAllString(html, "\n") - re, _ = regexp.Compile("\\s{2,}") - src = re.ReplaceAllString(src, "\n") + re, _ = regexp.Compile(`\s{2,}`) + html = re.ReplaceAllString(html, "\n") - return strings.TrimSpace(src) + return strings.TrimSpace(html) } // DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" @@ -193,7 +199,7 @@ func Str2html(raw string) template.HTML { } // Htmlquote returns quoted html string. -func Htmlquote(src string) string { +func Htmlquote(text string) string { //HTML编码为实体符号 /* Encodes `text` for raw use in HTML. @@ -201,8 +207,6 @@ func Htmlquote(src string) string { '<'&">' */ - text := string(src) - text = strings.Replace(text, "&", "&", -1) // Must be done first! text = strings.Replace(text, "<", "<", -1) text = strings.Replace(text, ">", ">", -1) @@ -216,7 +220,7 @@ func Htmlquote(src string) string { } // Htmlunquote returns unquoted html string. -func Htmlunquote(src string) string { +func Htmlunquote(text string) string { //实体符号解释为HTML /* Decodes `text` that's HTML quoted. @@ -227,7 +231,6 @@ func Htmlunquote(src string) string { // strings.Replace(s, old, new, n) // 在s字符串中,把old字符串替换为new字符串,n表示替换的次数,小于0表示全部替换 - text := string(src) text = strings.Replace(text, " ", " ", -1) text = strings.Replace(text, "”", "”", -1) text = strings.Replace(text, "“", "“", -1) @@ -262,19 +265,17 @@ func URLFor(endpoint string, values ...interface{}) string { } // AssetsJs returns script tag with src string. -func AssetsJs(src string) template.HTML { - text := string(src) +func AssetsJs(text string) template.HTML { - text = "" + text = "" return template.HTML(text) } // AssetsCSS returns stylesheet link tag with src string. -func AssetsCSS(src string) template.HTML { - text := string(src) +func AssetsCSS(text string) template.HTML { - text = "" + text = "" return template.HTML(text) } @@ -352,11 +353,32 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e case reflect.Struct: switch fieldT.Type.String() { case "time.Time": - format := time.RFC3339 - if len(tags) > 1 { - format = tags[1] + var ( + t time.Time + err error + ) + if len(value) >= 25 { + value = value[:25] + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + value = value[:19] + t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) + } else { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + t, err = time.ParseInLocation(formatDate, value, time.Local) + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + t, err = time.ParseInLocation(formatTime, value, time.Local) } - t, err := time.ParseInLocation(format, value, time.Local) if err != nil { return err } @@ -490,9 +512,9 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str class = fieldT.Tag.Get("class") required = false - required_field := fieldT.Tag.Get("required") - if required_field != "-" && required_field != "" { - required, _ = strconv.ParseBool(required_field) + requiredField := fieldT.Tag.Get("required") + if requiredField != "-" && requiredField != "" { + required, _ = strconv.ParseBool(requiredField) } switch len(tags) { diff --git a/templatefunc_test.go b/templatefunc_test.go index a1ec1136..9df61125 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -173,7 +173,7 @@ func TestParseForm(t *testing.T) { if u.Intro != "I am an engineer!" { t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) } - if u.StrBool != true { + if !u.StrBool { t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) } y, m, d := u.Date.Date() @@ -254,44 +254,44 @@ func TestParseFormTag(t *testing.T) { objT := reflect.TypeOf(&user{}).Elem() - label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0)) - if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) { + label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) + if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { t.Errorf("Form Tag with name, label and type was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1)) - if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) { + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) + if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { t.Errorf("Form Tag with label and type but without name was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2)) - if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) { + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) + if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { t.Errorf("Form Tag containing only label was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3)) - if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false && + label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) + if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && id == "name" && class == "form-name") { t.Errorf("Form Tag containing only name was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4)) - if ignored == false { + _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) + if !ignored { t.Errorf("Form Tag that should be ignored was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5)) - if !(name == "name" && required == true) { + _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) + if !(name == "name" && required) { t.Errorf("Form Tag containing only name and required was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6)) - if !(name == "name" && required == false) { + _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) + if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") } - label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7)) - if !(name == "name" && required == false) { + _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) + if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and not required was not correctly parsed.") } diff --git a/toolbox/statistics.go b/toolbox/statistics.go index c6a9489f..d014544c 100644 --- a/toolbox/statistics.go +++ b/toolbox/statistics.go @@ -119,7 +119,7 @@ func (m *URLMap) GetMap() map[string]interface{} { func (m *URLMap) GetMapData() []map[string]interface{} { m.lock.Lock() defer m.lock.Unlock() - + var resultLists []map[string]interface{} for k, v := range m.urlmap { diff --git a/toolbox/task.go b/toolbox/task.go index abd411c8..672717cd 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -427,6 +427,7 @@ func run() { } continue case <-changed: + now = time.Now().Local() continue case <-stop: return diff --git a/tree.go b/tree.go index 25b78e50..2d6c3fc3 100644 --- a/tree.go +++ b/tree.go @@ -288,10 +288,10 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{ return nil } w := make([]string, 0, 20) - return t.match(pattern, w, ctx) + return t.match(pattern[1:], pattern, w, ctx) } -func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { +func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { if len(pattern) > 0 { i := 0 for ; i < len(pattern) && pattern[i] == '/'; i++ { @@ -301,13 +301,13 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte // Handle leaf nodes: if len(pattern) == 0 { for _, l := range t.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } if t.wildcard != nil { for _, l := range t.wildcard.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } @@ -327,7 +327,12 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte } for _, subTree := range t.fixrouters { if subTree.prefix == seg { - runObject = subTree.match(pattern, wildcardValues, ctx) + if len(pattern) != 0 && pattern[0] == '/' { + treePattern = pattern[1:] + } else { + treePattern = pattern + } + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { break } @@ -339,7 +344,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte if strings.HasSuffix(seg, str) { for _, subTree := range t.fixrouters { if subTree.prefix == seg[:len(seg)-len(str)] { - runObject = subTree.match(pattern, wildcardValues, ctx) + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { ctx.Input.SetParam(":ext", str[1:]) } @@ -349,7 +354,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte } } if runObject == nil && t.wildcard != nil { - runObject = t.wildcard.match(pattern, append(wildcardValues, seg), ctx) + runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) } if runObject == nil && len(t.leaves) > 0 { @@ -368,7 +373,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte wildcardValues = append(wildcardValues, pattern[start:i]) } for _, l := range t.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } @@ -386,7 +391,7 @@ type leafInfo struct { runObject interface{} } -func (leaf *leafInfo) match(wildcardValues []string, ctx *context.Context) (ok bool) { +func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path @@ -394,7 +399,7 @@ func (leaf *leafInfo) match(wildcardValues []string, ctx *context.Context) (ok b } // match * if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { - ctx.Input.SetParam(":splat", path.Join(wildcardValues...)) + ctx.Input.SetParam(":splat", treePattern) return true } // match *.* or :id diff --git a/tree_test.go b/tree_test.go index 81ff7edd..d412a348 100644 --- a/tree_test.go +++ b/tree_test.go @@ -42,7 +42,7 @@ func init() { routers = append(routers, testinfo{"/", "/", nil}) routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) - routers = append(routers, testinfo{"/*", "/customer/123", map[string]string{":splat": "customer/123"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) diff --git a/utils/captcha/image.go b/utils/captcha/image.go index 0ceb8e42..c3c9a83a 100644 --- a/utils/captcha/image.go +++ b/utils/captcha/image.go @@ -474,7 +474,7 @@ func randomBrightness(c color.RGBA, max uint8) color.RGBA { uint8(int(c.R) + n), uint8(int(c.G) + n), uint8(int(c.B) + n), - uint8(c.A), + c.A, } } diff --git a/utils/file.go b/utils/file.go index db197882..6090eb17 100644 --- a/utils/file.go +++ b/utils/file.go @@ -72,7 +72,7 @@ func GrepFile(patten string, filename string) (lines []string, err error) { lines = make([]string, 0) reader := bufio.NewReader(fd) prefix := "" - isLongLine := false + var isLongLine bool for { byteLine, isPrefix, er := reader.ReadLine() if er != nil && er != io.EOF { diff --git a/utils/file_test.go b/utils/file_test.go index 86d1a700..b2644157 100644 --- a/utils/file_test.go +++ b/utils/file_test.go @@ -54,7 +54,7 @@ func TestSearchFile(t *testing.T) { _, err = SearchFile(noExistedFile, ".") if err == nil { - t.Errorf("err shouldnot be nil, got path: %s", SelfDir()) + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) } } diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..ed885787 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,30 @@ +package utils + +import ( + "os" + "path/filepath" + "runtime" + "strings" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + gopath := os.Getenv("GOPATH") + if gopath == "" && strings.Compare(runtime.Version(), "go1.8") >= 0 { + gopath = defaultGOPATH() + } + return filepath.SplitList(gopath) +} + +func defaultGOPATH() string { + env := "HOME" + if runtime.GOOS == "windows" { + env = "USERPROFILE" + } else if runtime.GOOS == "plan9" { + env = "home" + } + if home := os.Getenv(env); home != "" { + return filepath.Join(home, "go") + } + return "" +} diff --git a/validation/util_test.go b/validation/util_test.go index d7e10506..e74d50ed 100644 --- a/validation/util_test.go +++ b/validation/util_test.go @@ -42,7 +42,7 @@ func TestGetValidFuncs(t *testing.T) { } f, _ = tf.FieldByName("Tag") - if vfs, err = getValidFuncs(f); err.Error() != "doesn't exsits Maxx valid function" { + if _, err = getValidFuncs(f); err.Error() != "doesn't exsits Maxx valid function" { t.Fatal(err) } diff --git a/validation/validation.go b/validation/validation.go index 489dfa5e..9dc51106 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -349,7 +349,7 @@ func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { //Step 1: validate obj itself firstly // fails if objc is not struct pass, err := v.Valid(objc) - if err != nil || false == pass { + if err != nil || !pass { return pass, err // Stop recursive validation } // Step 2: Validate struct's struct fields diff --git a/validation/validation_test.go b/validation/validation_test.go index ec65b6d0..71a8bd17 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -35,6 +35,12 @@ func TestRequired(t *testing.T) { if valid.Required("", "string").Ok { t.Error("\"'\" string should be false") } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } if !valid.Required("astaxie", "string").Ok { t.Error("string should be true") } @@ -175,10 +181,10 @@ func TestAlphaNumeric(t *testing.T) { func TestMatch(t *testing.T) { valid := Validation{} - if valid.Match("suchuangji@gmail", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") } } @@ -186,10 +192,10 @@ func TestMatch(t *testing.T) { func TestNoMatch(t *testing.T) { valid := Validation{} - if valid.NoMatch("123@gmail", regexp.MustCompile("[^\\w\\d]"), "nomatch").Ok { + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") } - if !valid.NoMatch("123gmail", regexp.MustCompile("[^\\w\\d]"), "match").Ok { + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") } } @@ -214,6 +220,12 @@ func TestEmail(t *testing.T) { if !valid.Email("suchuangji@gmail.com", "email").Ok { t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } } func TestIP(t *testing.T) { diff --git a/validation/validators.go b/validation/validators.go index 9b04c5ce..5d489a55 100644 --- a/validation/validators.go +++ b/validation/validators.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "regexp" + "strings" "time" "unicode/utf8" ) @@ -98,7 +99,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { } if str, ok := obj.(string); ok { - return len(str) > 0 + return len(strings.TrimSpace(str)) > 0 } if _, ok := obj.(bool); ok { return true @@ -145,7 +146,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default error message func (r Required) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Required"]) + return MessageTmpls["Required"] } // GetKey return the r.Key @@ -364,7 +365,7 @@ func (a Alpha) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (a Alpha) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Alpha"]) + return MessageTmpls["Alpha"] } // GetKey return the m.Key @@ -397,7 +398,7 @@ func (n Numeric) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (n Numeric) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Numeric"]) + return MessageTmpls["Numeric"] } // GetKey return the n.Key @@ -430,7 +431,7 @@ func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (a AlphaNumeric) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["AlphaNumeric"]) + return MessageTmpls["AlphaNumeric"] } // GetKey return the a.Key @@ -495,7 +496,7 @@ func (n NoMatch) GetLimitValue() interface{} { return n.Regexp.String() } -var alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]") +var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) // AlphaDash check not Alpha type AlphaDash struct { @@ -505,7 +506,7 @@ type AlphaDash struct { // DefaultMessage return the default AlphaDash error message func (a AlphaDash) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["AlphaDash"]) + return MessageTmpls["AlphaDash"] } // GetKey return the n.Key @@ -518,7 +519,7 @@ func (a AlphaDash) GetLimitValue() interface{} { return nil } -var emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?") +var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) // Email check struct type Email struct { @@ -528,7 +529,7 @@ type Email struct { // DefaultMessage return the default Email error message func (e Email) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Email"]) + return MessageTmpls["Email"] } // GetKey return the n.Key @@ -541,7 +542,7 @@ func (e Email) GetLimitValue() interface{} { return nil } -var ipPattern = regexp.MustCompile("^((2[0-4]\\d|25[0-5]|[01]?\\d\\d?)\\.){3}(2[0-4]\\d|25[0-5]|[01]?\\d\\d?)$") +var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) // IP check struct type IP struct { @@ -551,7 +552,7 @@ type IP struct { // DefaultMessage return the default IP error message func (i IP) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["IP"]) + return MessageTmpls["IP"] } // GetKey return the i.Key @@ -564,7 +565,7 @@ func (i IP) GetLimitValue() interface{} { return nil } -var base64Pattern = regexp.MustCompile("^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$") +var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) // Base64 check struct type Base64 struct { @@ -574,7 +575,7 @@ type Base64 struct { // DefaultMessage return the default Base64 error message func (b Base64) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Base64"]) + return MessageTmpls["Base64"] } // GetKey return the b.Key @@ -588,7 +589,7 @@ func (b Base64) GetLimitValue() interface{} { } // just for chinese mobile phone number -var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\\d{8}$") +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`) // Mobile check struct type Mobile struct { @@ -598,7 +599,7 @@ type Mobile struct { // DefaultMessage return the default Mobile error message func (m Mobile) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Mobile"]) + return MessageTmpls["Mobile"] } // GetKey return the m.Key @@ -612,7 +613,7 @@ func (m Mobile) GetLimitValue() interface{} { } // just for chinese telephone number -var telPattern = regexp.MustCompile("^(0\\d{2,3}(\\-)?)?\\d{7,8}$") +var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) // Tel check telephone struct type Tel struct { @@ -622,7 +623,7 @@ type Tel struct { // DefaultMessage return the default Tel error message func (t Tel) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Tel"]) + return MessageTmpls["Tel"] } // GetKey return the t.Key @@ -649,7 +650,7 @@ func (p Phone) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Phone error message func (p Phone) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Phone"]) + return MessageTmpls["Phone"] } // GetKey return the p.Key @@ -663,7 +664,7 @@ func (p Phone) GetLimitValue() interface{} { } // just for chinese zipcode -var zipCodePattern = regexp.MustCompile("^[1-9]\\d{5}$") +var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) // ZipCode check the zip struct type ZipCode struct { @@ -673,7 +674,7 @@ type ZipCode struct { // DefaultMessage return the default Zip error message func (z ZipCode) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["ZipCode"]) + return MessageTmpls["ZipCode"] } // GetKey return the z.Key