diff --git a/.travis.yml b/.travis.yml index 479d70ca..aa88a44d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -33,6 +33,8 @@ install: - go get github.com/ssdb/gossdb/ssdb - go get github.com/cloudflare/golz4 - go get github.com/gogo/protobuf/proto + - go get github.com/Knetic/govaluate + - go get github.com/hsluoyz/casbin - go get -u honnef.co/go/tools/cmd/gosimple - go get -u github.com/mdempsky/unconvert - go get -u github.com/gordonklaus/ineffassign diff --git a/beego.go b/beego.go index 7a8db390..22079a20 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.8.2" + VERSION = "1.8.3" // DEV is for develop DEV = "dev" diff --git a/cache/conv.go b/cache/conv.go index dbdff1c7..87800586 100644 --- a/cache/conv.go +++ b/cache/conv.go @@ -28,7 +28,7 @@ func GetString(v interface{}) string { return string(result) default: if v != nil { - return fmt.Sprintf("%v", result) + return fmt.Sprint(result) } } return "" diff --git a/context/context.go b/context/context.go index 03286097..8b32062c 100644 --- a/context/context.go +++ b/context/context.go @@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool { return true } +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + //Response is a wrapper for the http.ResponseWriter //started set to true if response was written to then don't execute other handler type Response struct { diff --git a/context/output.go b/context/output.go index 564ef96d..cf9e7a7e 100644 --- a/context/output.go +++ b/context/output.go @@ -168,6 +168,19 @@ func sanitizeValue(v string) string { return cookieValueSanitizer.Replace(v) } +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.WriteString(err.Error()) + }) +} + // JSON writes json to response body. // if coding is true, it converts utf-8 to \u0000 type. func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { @@ -330,9 +343,8 @@ func (output *BeegoOutput) IsServerError() bool { } func stringsToJSON(str string) string { - rs := []rune(str) var jsons bytes.Buffer - for _, r := range rs { + for _, r := range str { rint := int(r) if rint < 128 { jsons.WriteRune(r) diff --git a/context/param/conv.go b/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/context/param/methodparams.go b/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/context/param/options.go b/context/param/options.go new file mode 100644 index 00000000..58bdc3d0 --- /dev/null +++ b/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be ommited from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/context/param/parsers.go b/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/context/param/parsers_test.go b/context/param/parsers_test.go new file mode 100644 index 00000000..b946ba08 --- /dev/null +++ b/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Convertion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/context/renderer.go b/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/context/response.go b/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/controller.go b/controller.go index c2a327b3..510e16b8 100644 --- a/controller.go +++ b/controller.go @@ -28,6 +28,7 @@ import ( "strings" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/session" ) @@ -51,6 +52,7 @@ type ControllerComments struct { Router string AllowHTTPMethods []string Params []map[string]string + MethodParams []*param.MethodParam } // Controller defines some basic http request handler operations, such as diff --git a/error.go b/error.go index ab626247..b913db39 100644 --- a/error.go +++ b/error.go @@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) { ) } +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) +} + // show 404 not found error. func notFound(rw http.ResponseWriter, r *http.Request) { responseError(rw, r, diff --git a/hooks.go b/hooks.go index 0fddc82f..c5ec8e2d 100644 --- a/hooks.go +++ b/hooks.go @@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error { "502": badGateway, "503": serviceUnavailable, "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, } for e, h := range m { if _, ok := ErrorMaps[e]; !ok { diff --git a/orm/db.go b/orm/db.go index 2a05797a..12f0f54d 100644 --- a/orm/db.go +++ b/orm/db.go @@ -833,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con if err := rs.Scan(&ref); err != nil { return 0, err } - args = append(args, reflect.ValueOf(ref).Interface()) + pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + if err != nil { + return 0, err + } + args = append(args, pkValue) cnt++ } diff --git a/orm/models_boot.go b/orm/models_boot.go index 85d0917f..5327f754 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { } if mi.fields.pk == nil { - fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name) + fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) os.Exit(2) } diff --git a/orm/orm.go b/orm/orm.go index 5db79386..fcf82590 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect if mi, ok := modelCache.getByFullName(name); ok { return mi, ind } - panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) } // get field info from model info by given field name diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 1e86212a..c8e741ea 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { } } } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = snakeString(fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } } } + + // init call the recursive function + recursiveSetField(ind) } if eTyps[0].Kind() == reflect.Ptr { diff --git a/orm/orm_test.go b/orm/orm_test.go index c5bfa8b9..f1f2d85e 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1661,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(pid, nil)) } +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + func TestQueryRows(t *testing.T) { Q := dDbBaser.TableQuote() @@ -1731,6 +1738,19 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(usernames[1], "astaxie")) throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(usernames[2], "nobody")) + + //test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + } func TestRawValues(t *testing.T) { diff --git a/parser.go b/parser.go index d40ee3ce..a9cfd894 100644 --- a/parser.go +++ b/parser.go @@ -24,9 +24,13 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" "sort" + "strconv" "strings" + "unicode" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers import ( "github.com/astaxie/beego" + "github.com/astaxie/beego/context/param" ) func init() { @@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { if specDecl.Recv != nil { exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser if ok { - parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) } } } @@ -93,44 +98,169 @@ func parserPkg(pkgRealpath, pkgpath string) error { return nil } -func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error { - if comments != nil && comments.List != nil { - for _, c := range comments.List { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - elements := strings.TrimLeft(t, "@router ") - e1 := strings.SplitN(elements, " ", 2) - if len(e1) < 1 { - return errors.New("you should has router information") - } - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = funcName - cc.Router = e1[0] - if len(e1) == 2 && e1[1] != "" { - e1 = strings.SplitN(e1[1], " ", 2) - if len(e1) >= 1 { - cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",") - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - } else { - cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") - } - if len(e1) == 2 && e1[1] != "" { - keyval := strings.Split(strings.Trim(e1[1], "[]"), " ") - for _, kv := range keyval { - kk := strings.Split(kv, ":") - cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]}) - } - } - genInfoList[key] = append(genInfoList[key], cc) - } +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComment, err := parseComment(f.Doc.List) + if err != nil { + return err } + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + genInfoList[key] = append(genInfoList[key], cc) + } + } return nil } +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + methodParam := buildMethodParam(fparam, pc) + result = append(result, methodParam) + } + return result +} + +func buildMethodParam(fparam *ast.Field, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + name := fparam.Names[0].Name + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) +} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { + pc = &parsedComment{} + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + } else { + return nil, errors.New("Router information is missing") + } + } else if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.SplitN(pv[0], "=>", 2) + p.name = names[0] + funcParamName := p.name + if len(names) > 1 { + funcParamName = names[1] + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + if pc.params == nil { + pc.params = map[string]parsedParam{} + } + pc.params[funcParamName] = p + } + } + return +} + +// direct copy from bee\g_docs.go +// analisys params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range str { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + func genRouterCode(pkgRealpath string) { os.Mkdir(getRouterDir(pkgRealpath), 0755) logs.Info("generate router from comments") @@ -163,12 +293,24 @@ func genRouterCode(pkgRealpath string) { } params = strings.TrimRight(params, ",") + "}" } + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" globalinfo = globalinfo + ` beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.ControllerComments{ Method: "` + strings.TrimSpace(c.Method) + `", ` + "Router: `" + c.Router + "`" + `, AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, Params: ` + params + `}) ` } diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go new file mode 100644 index 00000000..709a613a --- /dev/null +++ b/plugins/authz/authz.go @@ -0,0 +1,86 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/hsluoyz/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/hsluoyz/casbin" + "net/http" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuthorizer{enforcer: e} + + if !a.CheckPermission(ctx.Request) { + a.RequirePermission(ctx.ResponseWriter) + } + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer struct { + enforcer *casbin.Enforcer +} + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + username, _, _ := r.BasicAuth() + return username +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + user := a.GetUserName(r) + method := r.Method + path := r.URL.Path + return a.enforcer.Enforce(user, path, method) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + w.WriteHeader(403) + w.Write([]byte("403 Forbidden\n")) +} diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go new file mode 100644 index 00000000..4003582c --- /dev/null +++ b/plugins/authz/authz_test.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/plugins/auth" + "github.com/hsluoyz/casbin" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/router.go b/router.go index 692e72b0..72476ae8 100644 --- a/router.go +++ b/router.go @@ -27,6 +27,7 @@ import ( "time" beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" @@ -116,6 +117,7 @@ type ControllerInfo struct { handler http.Handler runFunction FilterFunc routerType int + methodParams []*param.MethodParam } // ControllerRegister containers registered router rules, controller handlers and filters. @@ -151,6 +153,10 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + p.addWithMethodParams(pattern, c, nil, mappingMethods...) +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) @@ -181,6 +187,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM route.methods = methods route.routerType = routerTypeBeego route.controllerType = t + route.methodParams = methodParams if len(methods) == 0 { for _, m := range HTTPMETHOD { p.addToRouter(m, pattern, route) @@ -245,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { key := t.PkgPath() + ":" + t.Name() if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { - p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) } } } @@ -624,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { startTime := time.Now() var ( - runRouter reflect.Type - findRouter bool - runMethod string - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool ) context := p.pool.Get().(*beecontext.Context) context.Reset(rw, r) @@ -740,6 +748,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) routerInfo.handler.ServeHTTP(rw, r) } else { runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams method := r.Method if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { method = http.MethodPut @@ -802,9 +811,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) execController.Options() default: if !execController.HandlerFunc(runMethod) { - var in []reflect.Value method := vc.MethodByName(runMethod) - method.Call(in) + in := param.ConvertParams(methodParams, method.Type(), context) + out := method.Call(in) + + //For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(context, execController, out) + } } } @@ -884,6 +898,20 @@ Admin: } } +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + //looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + // FindRouter Find Router info for URL func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { var urlPath = context.Input.URL() diff --git a/swagger/swagger.go b/swagger/swagger.go index c687fb8e..035d5a49 100644 --- a/swagger/swagger.go +++ b/swagger/swagger.go @@ -22,19 +22,19 @@ package swagger // Swagger list the resource type Swagger struct { - SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` - Infos Information `json:"info" yaml:"info"` - Host string `json:"host,omitempty" yaml:"host,omitempty"` - BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Paths map[string]*Item `json:"paths" yaml:"paths"` - Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` - SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` - Security map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` + SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` + Infos Information `json:"info" yaml:"info"` + Host string `json:"host,omitempty" yaml:"host,omitempty"` + BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Paths map[string]*Item `json:"paths" yaml:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` + SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` } // Information Provides metadata about the API. The metadata can be used by the clients if needed. @@ -75,16 +75,17 @@ type Item struct { // Operation Describes a single API operation on a path. type Operation struct { - Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` - Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` - Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` - Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` } // Parameter Describes a single operation parameter. diff --git a/templatefunc.go b/templatefunc.go index a6f9c961..a104fd24 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -27,9 +27,10 @@ import ( ) const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" ) // Substr returns the substr from start to length. @@ -53,21 +54,21 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - re, _ := regexp.Compile("\\<[\\S\\s]+?\\>") + re, _ := regexp.Compile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) //remove STYLE - re, _ = regexp.Compile("\\") + re, _ = regexp.Compile(`\`) html = re.ReplaceAllString(html, "") //remove SCRIPT - re, _ = regexp.Compile("\\") + re, _ = regexp.Compile(`\`) html = re.ReplaceAllString(html, "") - re, _ = regexp.Compile("\\<[\\S\\s]+?\\>") + re, _ = regexp.Compile(`\<[\S\s]+?\>`) html = re.ReplaceAllString(html, "\n") - re, _ = regexp.Compile("\\s{2,}") + re, _ = regexp.Compile(`\s{2,}`) html = re.ReplaceAllString(html, "\n") return strings.TrimSpace(html) @@ -360,8 +361,13 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { - value = value[:19] - t, err = time.ParseInLocation(formatDateTime, value, time.Local) + if strings.Contains(value, "T") { + value = value[:19] + t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) + } else { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } } else if len(value) >= 10 { if len(value) > 10 { value = value[:10] @@ -373,7 +379,6 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e } t, err = time.ParseInLocation(formatTime, value, time.Local) } - if err != nil { return err } diff --git a/validation/validation_test.go b/validation/validation_test.go index 83e881bf..71a8bd17 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -35,6 +35,12 @@ func TestRequired(t *testing.T) { if valid.Required("", "string").Ok { t.Error("\"'\" string should be false") } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } if !valid.Required("astaxie", "string").Ok { t.Error("string should be true") } @@ -175,10 +181,10 @@ func TestAlphaNumeric(t *testing.T) { func TestMatch(t *testing.T) { valid := Validation{} - if valid.Match("suchuangji@gmail", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") } } @@ -186,10 +192,10 @@ func TestMatch(t *testing.T) { func TestNoMatch(t *testing.T) { valid := Validation{} - if valid.NoMatch("123@gmail", regexp.MustCompile("[^\\w\\d]"), "nomatch").Ok { + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") } - if !valid.NoMatch("123gmail", regexp.MustCompile("[^\\w\\d]"), "match").Ok { + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") } } diff --git a/validation/validators.go b/validation/validators.go index 01aed443..5d489a55 100644 --- a/validation/validators.go +++ b/validation/validators.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "regexp" + "strings" "time" "unicode/utf8" ) @@ -98,7 +99,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { } if str, ok := obj.(string); ok { - return len(str) > 0 + return len(strings.TrimSpace(str)) > 0 } if _, ok := obj.(bool); ok { return true @@ -145,7 +146,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default error message func (r Required) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Required"]) + return MessageTmpls["Required"] } // GetKey return the r.Key @@ -364,7 +365,7 @@ func (a Alpha) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (a Alpha) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Alpha"]) + return MessageTmpls["Alpha"] } // GetKey return the m.Key @@ -397,7 +398,7 @@ func (n Numeric) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (n Numeric) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Numeric"]) + return MessageTmpls["Numeric"] } // GetKey return the n.Key @@ -430,7 +431,7 @@ func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Length error message func (a AlphaNumeric) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["AlphaNumeric"]) + return MessageTmpls["AlphaNumeric"] } // GetKey return the a.Key @@ -495,7 +496,7 @@ func (n NoMatch) GetLimitValue() interface{} { return n.Regexp.String() } -var alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]") +var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) // AlphaDash check not Alpha type AlphaDash struct { @@ -505,7 +506,7 @@ type AlphaDash struct { // DefaultMessage return the default AlphaDash error message func (a AlphaDash) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["AlphaDash"]) + return MessageTmpls["AlphaDash"] } // GetKey return the n.Key @@ -518,7 +519,7 @@ func (a AlphaDash) GetLimitValue() interface{} { return nil } -var emailPattern = regexp.MustCompile("^[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?$") +var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) // Email check struct type Email struct { @@ -528,7 +529,7 @@ type Email struct { // DefaultMessage return the default Email error message func (e Email) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Email"]) + return MessageTmpls["Email"] } // GetKey return the n.Key @@ -541,7 +542,7 @@ func (e Email) GetLimitValue() interface{} { return nil } -var ipPattern = regexp.MustCompile("^((2[0-4]\\d|25[0-5]|[01]?\\d\\d?)\\.){3}(2[0-4]\\d|25[0-5]|[01]?\\d\\d?)$") +var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) // IP check struct type IP struct { @@ -551,7 +552,7 @@ type IP struct { // DefaultMessage return the default IP error message func (i IP) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["IP"]) + return MessageTmpls["IP"] } // GetKey return the i.Key @@ -564,7 +565,7 @@ func (i IP) GetLimitValue() interface{} { return nil } -var base64Pattern = regexp.MustCompile("^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$") +var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) // Base64 check struct type Base64 struct { @@ -574,7 +575,7 @@ type Base64 struct { // DefaultMessage return the default Base64 error message func (b Base64) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Base64"]) + return MessageTmpls["Base64"] } // GetKey return the b.Key @@ -588,7 +589,7 @@ func (b Base64) GetLimitValue() interface{} { } // just for chinese mobile phone number -var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\\d{8}$") +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`) // Mobile check struct type Mobile struct { @@ -598,7 +599,7 @@ type Mobile struct { // DefaultMessage return the default Mobile error message func (m Mobile) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Mobile"]) + return MessageTmpls["Mobile"] } // GetKey return the m.Key @@ -612,7 +613,7 @@ func (m Mobile) GetLimitValue() interface{} { } // just for chinese telephone number -var telPattern = regexp.MustCompile("^(0\\d{2,3}(\\-)?)?\\d{7,8}$") +var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) // Tel check telephone struct type Tel struct { @@ -622,7 +623,7 @@ type Tel struct { // DefaultMessage return the default Tel error message func (t Tel) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Tel"]) + return MessageTmpls["Tel"] } // GetKey return the t.Key @@ -649,7 +650,7 @@ func (p Phone) IsSatisfied(obj interface{}) bool { // DefaultMessage return the default Phone error message func (p Phone) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["Phone"]) + return MessageTmpls["Phone"] } // GetKey return the p.Key @@ -663,7 +664,7 @@ func (p Phone) GetLimitValue() interface{} { } // just for chinese zipcode -var zipCodePattern = regexp.MustCompile("^[1-9]\\d{5}$") +var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) // ZipCode check the zip struct type ZipCode struct { @@ -673,7 +674,7 @@ type ZipCode struct { // DefaultMessage return the default Zip error message func (z ZipCode) DefaultMessage() string { - return fmt.Sprint(MessageTmpls["ZipCode"]) + return MessageTmpls["ZipCode"] } // GetKey return the z.Key