diff --git a/.travis.yml b/.travis.yml
index 479d70ca..2937e6e8 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -33,6 +33,8 @@ install:
- go get github.com/ssdb/gossdb/ssdb
- go get github.com/cloudflare/golz4
- go get github.com/gogo/protobuf/proto
+ - go get github.com/Knetic/govaluate
+ - go get github.com/casbin/casbin
- go get -u honnef.co/go/tools/cmd/gosimple
- go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign
diff --git a/beego.go b/beego.go
index 7a8db390..22079a20 100644
--- a/beego.go
+++ b/beego.go
@@ -23,7 +23,7 @@ import (
const (
// VERSION represent beego web framework version.
- VERSION = "1.8.2"
+ VERSION = "1.8.3"
// DEV is for develop
DEV = "dev"
diff --git a/cache/conv.go b/cache/conv.go
index dbdff1c7..87800586 100644
--- a/cache/conv.go
+++ b/cache/conv.go
@@ -28,7 +28,7 @@ func GetString(v interface{}) string {
return string(result)
default:
if v != nil {
- return fmt.Sprintf("%v", result)
+ return fmt.Sprint(result)
}
}
return ""
diff --git a/context/context.go b/context/context.go
index 03286097..8b32062c 100644
--- a/context/context.go
+++ b/context/context.go
@@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool {
return true
}
+// RenderMethodResult renders the return value of a controller method to the output
+func (ctx *Context) RenderMethodResult(result interface{}) {
+ if result != nil {
+ renderer, ok := result.(Renderer)
+ if !ok {
+ err, ok := result.(error)
+ if ok {
+ renderer = errorRenderer(err)
+ } else {
+ renderer = jsonRenderer(result)
+ }
+ }
+ renderer.Render(ctx)
+ }
+}
+
//Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type Response struct {
diff --git a/context/output.go b/context/output.go
index 564ef96d..cf9e7a7e 100644
--- a/context/output.go
+++ b/context/output.go
@@ -168,6 +168,19 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
+func jsonRenderer(value interface{}) Renderer {
+ return rendererFunc(func(ctx *Context) {
+ ctx.Output.JSON(value, false, false)
+ })
+}
+
+func errorRenderer(err error) Renderer {
+ return rendererFunc(func(ctx *Context) {
+ ctx.Output.SetStatus(500)
+ ctx.WriteString(err.Error())
+ })
+}
+
// JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
@@ -330,9 +343,8 @@ func (output *BeegoOutput) IsServerError() bool {
}
func stringsToJSON(str string) string {
- rs := []rune(str)
var jsons bytes.Buffer
- for _, r := range rs {
+ for _, r := range str {
rint := int(r)
if rint < 128 {
jsons.WriteRune(r)
diff --git a/context/param/conv.go b/context/param/conv.go
new file mode 100644
index 00000000..c200e008
--- /dev/null
+++ b/context/param/conv.go
@@ -0,0 +1,78 @@
+package param
+
+import (
+ "fmt"
+ "reflect"
+
+ beecontext "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
+)
+
+// ConvertParams converts http method params to values that will be passed to the method controller as arguments
+func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) {
+ result = make([]reflect.Value, 0, len(methodParams))
+ for i := 0; i < len(methodParams); i++ {
+ reflectValue := convertParam(methodParams[i], methodType.In(i), ctx)
+ result = append(result, reflectValue)
+ }
+ return
+}
+
+func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) {
+ paramValue := getParamValue(param, ctx)
+ if paramValue == "" {
+ if param.required {
+ ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name))
+ } else {
+ paramValue = param.defaultValue
+ }
+ }
+
+ reflectValue, err := parseValue(param, paramValue, paramType)
+ if err != nil {
+ logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err))
+ ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType))
+ }
+
+ return reflectValue
+}
+
+func getParamValue(param *MethodParam, ctx *beecontext.Context) string {
+ switch param.in {
+ case body:
+ return string(ctx.Input.RequestBody)
+ case header:
+ return ctx.Input.Header(param.name)
+ case path:
+ return ctx.Input.Query(":" + param.name)
+ default:
+ return ctx.Input.Query(param.name)
+ }
+}
+
+func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) {
+ if paramValue == "" {
+ return reflect.Zero(paramType), nil
+ }
+ parser := getParser(param, paramType)
+ value, err := parser.parse(paramValue, paramType)
+ if err != nil {
+ return result, err
+ }
+
+ return safeConvert(reflect.ValueOf(value), paramType)
+}
+
+func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ var ok bool
+ err, ok = r.(error)
+ if !ok {
+ err = fmt.Errorf("%v", r)
+ }
+ }
+ }()
+ result = value.Convert(t)
+ return
+}
diff --git a/context/param/methodparams.go b/context/param/methodparams.go
new file mode 100644
index 00000000..cd6708a2
--- /dev/null
+++ b/context/param/methodparams.go
@@ -0,0 +1,69 @@
+package param
+
+import (
+ "fmt"
+ "strings"
+)
+
+//MethodParam keeps param information to be auto passed to controller methods
+type MethodParam struct {
+ name string
+ in paramType
+ required bool
+ defaultValue string
+}
+
+type paramType byte
+
+const (
+ param paramType = iota
+ path
+ body
+ header
+)
+
+//New creates a new MethodParam with name and specific options
+func New(name string, opts ...MethodParamOption) *MethodParam {
+ return newParam(name, nil, opts)
+}
+
+func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) {
+ param = &MethodParam{name: name}
+ for _, option := range opts {
+ option(param)
+ }
+ return
+}
+
+//Make creates an array of MethodParmas or an empty array
+func Make(list ...*MethodParam) []*MethodParam {
+ if len(list) > 0 {
+ return list
+ }
+ return nil
+}
+
+func (mp *MethodParam) String() string {
+ options := []string{}
+ result := "param.New(\"" + mp.name + "\""
+ if mp.required {
+ options = append(options, "param.IsRequired")
+ }
+ switch mp.in {
+ case path:
+ options = append(options, "param.InPath")
+ case body:
+ options = append(options, "param.InBody")
+ case header:
+ options = append(options, "param.InHeader")
+ }
+ if mp.defaultValue != "" {
+ options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue))
+ }
+ if len(options) > 0 {
+ result += ", "
+ }
+ result += strings.Join(options, ", ")
+ result += ")"
+ return result
+}
diff --git a/context/param/options.go b/context/param/options.go
new file mode 100644
index 00000000..58bdc3d0
--- /dev/null
+++ b/context/param/options.go
@@ -0,0 +1,37 @@
+package param
+
+import (
+ "fmt"
+)
+
+// MethodParamOption defines a func which apply options on a MethodParam
+type MethodParamOption func(*MethodParam)
+
+// IsRequired indicates that this param is required and can not be ommited from the http request
+var IsRequired MethodParamOption = func(p *MethodParam) {
+ p.required = true
+}
+
+// InHeader indicates that this param is passed via an http header
+var InHeader MethodParamOption = func(p *MethodParam) {
+ p.in = header
+}
+
+// InPath indicates that this param is part of the URL path
+var InPath MethodParamOption = func(p *MethodParam) {
+ p.in = path
+}
+
+// InBody indicates that this param is passed as an http request body
+var InBody MethodParamOption = func(p *MethodParam) {
+ p.in = body
+}
+
+// Default provides a default value for the http param
+func Default(defaultValue interface{}) MethodParamOption {
+ return func(p *MethodParam) {
+ if defaultValue != nil {
+ p.defaultValue = fmt.Sprint(defaultValue)
+ }
+ }
+}
diff --git a/context/param/parsers.go b/context/param/parsers.go
new file mode 100644
index 00000000..421aecf0
--- /dev/null
+++ b/context/param/parsers.go
@@ -0,0 +1,149 @@
+package param
+
+import (
+ "encoding/json"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type paramParser interface {
+ parse(value string, toType reflect.Type) (interface{}, error)
+}
+
+func getParser(param *MethodParam, t reflect.Type) paramParser {
+ switch t.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return intParser{}
+ case reflect.Slice:
+ if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string
+ return stringParser{}
+ }
+ if param.in == body {
+ return jsonParser{}
+ }
+ elemParser := getParser(param, t.Elem())
+ if elemParser == (jsonParser{}) {
+ return elemParser
+ }
+ return sliceParser(elemParser)
+ case reflect.Bool:
+ return boolParser{}
+ case reflect.String:
+ return stringParser{}
+ case reflect.Float32, reflect.Float64:
+ return floatParser{}
+ case reflect.Ptr:
+ elemParser := getParser(param, t.Elem())
+ if elemParser == (jsonParser{}) {
+ return elemParser
+ }
+ return ptrParser(elemParser)
+ default:
+ if t.PkgPath() == "time" && t.Name() == "Time" {
+ return timeParser{}
+ }
+ return jsonParser{}
+ }
+}
+
+type parserFunc func(value string, toType reflect.Type) (interface{}, error)
+
+func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) {
+ return f(value, toType)
+}
+
+type boolParser struct {
+}
+
+func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) {
+ return strconv.ParseBool(value)
+}
+
+type stringParser struct {
+}
+
+func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) {
+ return value, nil
+}
+
+type intParser struct {
+}
+
+func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) {
+ return strconv.Atoi(value)
+}
+
+type floatParser struct {
+}
+
+func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) {
+ if toType.Kind() == reflect.Float32 {
+ res, err := strconv.ParseFloat(value, 32)
+ if err != nil {
+ return nil, err
+ }
+ return float32(res), nil
+ }
+ return strconv.ParseFloat(value, 64)
+}
+
+type timeParser struct {
+}
+
+func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) {
+ result, err = time.Parse(time.RFC3339, value)
+ if err != nil {
+ result, err = time.Parse("2006-01-02", value)
+ }
+ return
+}
+
+type jsonParser struct {
+}
+
+func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) {
+ pResult := reflect.New(toType)
+ v := pResult.Interface()
+ err := json.Unmarshal([]byte(value), v)
+ if err != nil {
+ return nil, err
+ }
+ return pResult.Elem().Interface(), nil
+}
+
+func sliceParser(elemParser paramParser) paramParser {
+ return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
+ values := strings.Split(value, ",")
+ result := reflect.MakeSlice(toType, 0, len(values))
+ elemType := toType.Elem()
+ for _, v := range values {
+ parsedValue, err := elemParser.parse(v, elemType)
+ if err != nil {
+ return nil, err
+ }
+ result = reflect.Append(result, reflect.ValueOf(parsedValue))
+ }
+ return result.Interface(), nil
+ })
+}
+
+func ptrParser(elemParser paramParser) paramParser {
+ return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
+ parsedValue, err := elemParser.parse(value, toType.Elem())
+ if err != nil {
+ return nil, err
+ }
+ newValPtr := reflect.New(toType.Elem())
+ newVal := reflect.Indirect(newValPtr)
+ convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem())
+ if err != nil {
+ return nil, err
+ }
+
+ newVal.Set(convertedVal)
+ return newValPtr.Interface(), nil
+ })
+}
diff --git a/context/param/parsers_test.go b/context/param/parsers_test.go
new file mode 100644
index 00000000..b946ba08
--- /dev/null
+++ b/context/param/parsers_test.go
@@ -0,0 +1,84 @@
+package param
+
+import "testing"
+import "reflect"
+import "time"
+
+type testDefinition struct {
+ strValue string
+ expectedValue interface{}
+ expectedParser paramParser
+}
+
+func Test_Parsers(t *testing.T) {
+
+ //ints
+ checkParser(testDefinition{"1", 1, intParser{}}, t)
+ checkParser(testDefinition{"-1", int64(-1), intParser{}}, t)
+ checkParser(testDefinition{"1", uint64(1), intParser{}}, t)
+
+ //floats
+ checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t)
+ checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t)
+
+ //strings
+ checkParser(testDefinition{"AB", "AB", stringParser{}}, t)
+ checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t)
+
+ //bools
+ checkParser(testDefinition{"true", true, boolParser{}}, t)
+ checkParser(testDefinition{"0", false, boolParser{}}, t)
+
+ //timeParser
+ checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t)
+ checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t)
+
+ //json
+ checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct {
+ X int
+ Y string
+ }{5, "Z"}, jsonParser{}}, t)
+
+ //slice in query is parsed as comma delimited
+ checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t)
+
+ //slice in body is parsed as json
+ checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body})
+
+ //pointers
+ var someInt = 1
+ checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t)
+
+ var someStruct = struct{ X int }{5}
+ checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t)
+
+}
+
+func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
+ toType := reflect.TypeOf(def.expectedValue)
+ var mp MethodParam
+ if len(methodParam) == 0 {
+ mp = MethodParam{}
+ } else {
+ mp = methodParam[0]
+ }
+ parser := getParser(&mp, toType)
+
+ if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) {
+ t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name())
+ return
+ }
+ result, err := parser.parse(def.strValue, toType)
+ if err != nil {
+ t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err)
+ return
+ }
+ convResult, err := safeConvert(reflect.ValueOf(result), toType)
+ if err != nil {
+ t.Errorf("Convertion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
+ return
+ }
+ if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {
+ t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result)
+ }
+}
diff --git a/context/renderer.go b/context/renderer.go
new file mode 100644
index 00000000..36a7cb53
--- /dev/null
+++ b/context/renderer.go
@@ -0,0 +1,12 @@
+package context
+
+// Renderer defines an http response renderer
+type Renderer interface {
+ Render(ctx *Context)
+}
+
+type rendererFunc func(ctx *Context)
+
+func (f rendererFunc) Render(ctx *Context) {
+ f(ctx)
+}
diff --git a/context/response.go b/context/response.go
new file mode 100644
index 00000000..9c3c715a
--- /dev/null
+++ b/context/response.go
@@ -0,0 +1,27 @@
+package context
+
+import (
+ "strconv"
+
+ "net/http"
+)
+
+const (
+ //BadRequest indicates http error 400
+ BadRequest StatusCode = http.StatusBadRequest
+
+ //NotFound indicates http error 404
+ NotFound StatusCode = http.StatusNotFound
+)
+
+// StatusCode sets the http response status code
+type StatusCode int
+
+func (s StatusCode) Error() string {
+ return strconv.Itoa(int(s))
+}
+
+// Render sets the http status code
+func (s StatusCode) Render(ctx *Context) {
+ ctx.Output.SetStatus(int(s))
+}
diff --git a/controller.go b/controller.go
index c2a327b3..510e16b8 100644
--- a/controller.go
+++ b/controller.go
@@ -28,6 +28,7 @@ import (
"strings"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/session"
)
@@ -51,6 +52,7 @@ type ControllerComments struct {
Router string
AllowHTTPMethods []string
Params []map[string]string
+ MethodParams []*param.MethodParam
}
// Controller defines some basic http request handler operations, such as
diff --git a/error.go b/error.go
index ab626247..b913db39 100644
--- a/error.go
+++ b/error.go
@@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
)
}
+// show 422 missing xsrf token
+func missingxsrf(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 422,
+ "
The page you have requested is forbidden."+
+ "
Perhaps you are here because:"+
+ "
"+
+ "
'_xsrf' argument missing from POST"+
+ "
",
+ )
+}
+
+// show 417 invalid xsrf token
+func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 417,
+ "
The page you have requested is forbidden."+
+ "
Perhaps you are here because:"+
+ "
"+
+ "
expected XSRF not found"+
+ "
",
+ )
+}
+
// show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
diff --git a/hooks.go b/hooks.go
index 0fddc82f..c5ec8e2d 100644
--- a/hooks.go
+++ b/hooks.go
@@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error {
"502": badGateway,
"503": serviceUnavailable,
"504": gatewayTimeout,
+ "417": invalidxsrf,
+ "422": missingxsrf,
}
for e, h := range m {
if _, ok := ErrorMaps[e]; !ok {
diff --git a/orm/models_boot.go b/orm/models_boot.go
index 85d0917f..5327f754 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
}
if mi.fields.pk == nil {
- fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name)
+ fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name)
os.Exit(2)
}
diff --git a/orm/orm.go b/orm/orm.go
index 5db79386..fcf82590 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind
}
- panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name))
+ panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
}
// get field info from model info by given field name
diff --git a/orm/orm_raw.go b/orm/orm_raw.go
index 1e86212a..c8e741ea 100644
--- a/orm/orm_raw.go
+++ b/orm/orm_raw.go
@@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
}
}
} else {
- for i := 0; i < ind.NumField(); i++ {
- f := ind.Field(i)
- fe := ind.Type().Field(i)
- _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
- var col string
- if col = tags["column"]; col == "" {
- col = snakeString(fe.Name)
- }
- if v, ok := columnsMp[col]; ok {
- value := reflect.ValueOf(v).Elem().Interface()
- o.setFieldValue(f, value)
+ // define recursive function
+ var recursiveSetField func(rv reflect.Value)
+ recursiveSetField = func(rv reflect.Value) {
+ for i := 0; i < rv.NumField(); i++ {
+ f := rv.Field(i)
+ fe := rv.Type().Field(i)
+
+ // check if the field is a Struct
+ // recursive the Struct type
+ if fe.Type.Kind() == reflect.Struct {
+ recursiveSetField(f)
+ }
+
+ _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
+ var col string
+ if col = tags["column"]; col == "" {
+ col = snakeString(fe.Name)
+ }
+ if v, ok := columnsMp[col]; ok {
+ value := reflect.ValueOf(v).Elem().Interface()
+ o.setFieldValue(f, value)
+ }
}
}
+
+ // init call the recursive function
+ recursiveSetField(ind)
}
if eTyps[0].Kind() == reflect.Ptr {
diff --git a/orm/orm_test.go b/orm/orm_test.go
index c5bfa8b9..f1f2d85e 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -1661,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(pid, nil))
}
+// user_profile table
+type userProfile struct {
+ User
+ Age int
+ Money float64
+}
+
func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote()
@@ -1731,6 +1738,19 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody"))
+
+ //test query rows by nested struct
+ var l []userProfile
+ query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q)
+ num, err = dORM.Raw(query).QueryRows(&l)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
+ throwFailNow(t, AssertIs(len(l), 2))
+ throwFailNow(t, AssertIs(l[0].UserName, "slene"))
+ throwFailNow(t, AssertIs(l[0].Age, 28))
+ throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
+ throwFailNow(t, AssertIs(l[1].Age, 30))
+
}
func TestRawValues(t *testing.T) {
diff --git a/parser.go b/parser.go
index d40ee3ce..a9cfd894 100644
--- a/parser.go
+++ b/parser.go
@@ -24,9 +24,13 @@ import (
"io/ioutil"
"os"
"path/filepath"
+ "regexp"
"sort"
+ "strconv"
"strings"
+ "unicode"
+ "github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers
import (
"github.com/astaxie/beego"
+ "github.com/astaxie/beego/context/param"
)
func init() {
@@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
if specDecl.Recv != nil {
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
if ok {
- parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath)
+ parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
}
}
}
@@ -93,44 +98,169 @@ func parserPkg(pkgRealpath, pkgpath string) error {
return nil
}
-func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error {
- if comments != nil && comments.List != nil {
- for _, c := range comments.List {
- t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
- if strings.HasPrefix(t, "@router") {
- elements := strings.TrimLeft(t, "@router ")
- e1 := strings.SplitN(elements, " ", 2)
- if len(e1) < 1 {
- return errors.New("you should has router information")
- }
- key := pkgpath + ":" + controllerName
- cc := ControllerComments{}
- cc.Method = funcName
- cc.Router = e1[0]
- if len(e1) == 2 && e1[1] != "" {
- e1 = strings.SplitN(e1[1], " ", 2)
- if len(e1) >= 1 {
- cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",")
- } else {
- cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
- }
- } else {
- cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
- }
- if len(e1) == 2 && e1[1] != "" {
- keyval := strings.Split(strings.Trim(e1[1], "[]"), " ")
- for _, kv := range keyval {
- kk := strings.Split(kv, ":")
- cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]})
- }
- }
- genInfoList[key] = append(genInfoList[key], cc)
- }
+type parsedComment struct {
+ routerPath string
+ methods []string
+ params map[string]parsedParam
+}
+
+type parsedParam struct {
+ name string
+ datatype string
+ location string
+ defValue string
+ required bool
+}
+
+func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
+ if f.Doc != nil {
+ parsedComment, err := parseComment(f.Doc.List)
+ if err != nil {
+ return err
}
+ if parsedComment.routerPath != "" {
+ key := pkgpath + ":" + controllerName
+ cc := ControllerComments{}
+ cc.Method = f.Name.String()
+ cc.Router = parsedComment.routerPath
+ cc.AllowHTTPMethods = parsedComment.methods
+ cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
+ genInfoList[key] = append(genInfoList[key], cc)
+ }
+
}
return nil
}
+func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
+ result := make([]*param.MethodParam, 0, len(funcParams))
+ for _, fparam := range funcParams {
+ methodParam := buildMethodParam(fparam, pc)
+ result = append(result, methodParam)
+ }
+ return result
+}
+
+func buildMethodParam(fparam *ast.Field, pc *parsedComment) *param.MethodParam {
+ options := []param.MethodParamOption{}
+ name := fparam.Names[0].Name
+ if cparam, ok := pc.params[name]; ok {
+ //Build param from comment info
+ name = cparam.name
+ if cparam.required {
+ options = append(options, param.IsRequired)
+ }
+ switch cparam.location {
+ case "body":
+ options = append(options, param.InBody)
+ case "header":
+ options = append(options, param.InHeader)
+ case "path":
+ options = append(options, param.InPath)
+ }
+ if cparam.defValue != "" {
+ options = append(options, param.Default(cparam.defValue))
+ }
+ } else {
+ if paramInPath(name, pc.routerPath) {
+ options = append(options, param.InPath)
+ }
+ }
+ return param.New(name, options...)
+}
+
+func paramInPath(name, route string) bool {
+ return strings.HasSuffix(route, ":"+name) ||
+ strings.Contains(route, ":"+name+"/")
+}
+
+var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
+
+func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
+ pc = &parsedComment{}
+ for _, c := range lines {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@router") {
+ matches := routeRegex.FindStringSubmatch(t)
+ if len(matches) == 3 {
+ pc.routerPath = matches[1]
+ methods := matches[2]
+ if methods == "" {
+ pc.methods = []string{"get"}
+ //pc.hasGet = true
+ } else {
+ pc.methods = strings.Split(methods, ",")
+ //pc.hasGet = strings.Contains(methods, "get")
+ }
+ } else {
+ return nil, errors.New("Router information is missing")
+ }
+ } else if strings.HasPrefix(t, "@Param") {
+ pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
+ if len(pv) < 4 {
+ logs.Error("Invalid @Param format. Needs at least 4 parameters")
+ }
+ p := parsedParam{}
+ names := strings.SplitN(pv[0], "=>", 2)
+ p.name = names[0]
+ funcParamName := p.name
+ if len(names) > 1 {
+ funcParamName = names[1]
+ }
+ p.location = pv[1]
+ p.datatype = pv[2]
+ switch len(pv) {
+ case 5:
+ p.required, _ = strconv.ParseBool(pv[3])
+ case 6:
+ p.defValue = pv[3]
+ p.required, _ = strconv.ParseBool(pv[4])
+ }
+ if pc.params == nil {
+ pc.params = map[string]parsedParam{}
+ }
+ pc.params[funcParamName] = p
+ }
+ }
+ return
+}
+
+// direct copy from bee\g_docs.go
+// analisys params return []string
+// @Param query form string true "The email for login"
+// [query form string true "The email for login"]
+func getparams(str string) []string {
+ var s []rune
+ var j int
+ var start bool
+ var r []string
+ var quoted int8
+ for _, c := range str {
+ if unicode.IsSpace(c) && quoted == 0 {
+ if !start {
+ continue
+ } else {
+ start = false
+ j++
+ r = append(r, string(s))
+ s = make([]rune, 0)
+ continue
+ }
+ }
+
+ start = true
+ if c == '"' {
+ quoted ^= 1
+ continue
+ }
+ s = append(s, c)
+ }
+ if len(s) > 0 {
+ r = append(r, string(s))
+ }
+ return r
+}
+
func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments")
@@ -163,12 +293,24 @@ func genRouterCode(pkgRealpath string) {
}
params = strings.TrimRight(params, ",") + "}"
}
+ methodParams := "param.Make("
+ if len(c.MethodParams) > 0 {
+ lines := make([]string, 0, len(c.MethodParams))
+ for _, m := range c.MethodParams {
+ lines = append(lines, fmt.Sprint(m))
+ }
+ methodParams += "\n " +
+ strings.Join(lines, ",\n ") +
+ ",\n "
+ }
+ methodParams += ")"
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `,
+ MethodParams: ` + methodParams + `,
Params: ` + params + `})
`
}
diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go
new file mode 100644
index 00000000..9dc0db76
--- /dev/null
+++ b/plugins/authz/authz.go
@@ -0,0 +1,86 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
+// Simple Usage:
+// import(
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/authz"
+// "github.com/casbin/casbin"
+// )
+//
+// func main(){
+// // mediate the access for every request
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+// beego.Run()
+// }
+//
+//
+// Advanced Usage:
+//
+// func main(){
+// e := casbin.NewEnforcer("authz_model.conf", "")
+// e.AddRoleForUser("alice", "admin")
+// e.AddPolicy(...)
+//
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
+// beego.Run()
+// }
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/casbin/casbin"
+ "net/http"
+)
+
+// NewAuthorizer returns the authorizer.
+// Use a casbin enforcer as input
+func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
+ return func(ctx *context.Context) {
+ a := &BasicAuthorizer{enforcer: e}
+
+ if !a.CheckPermission(ctx.Request) {
+ a.RequirePermission(ctx.ResponseWriter)
+ }
+ }
+}
+
+// BasicAuthorizer stores the casbin handler
+type BasicAuthorizer struct {
+ enforcer *casbin.Enforcer
+}
+
+// GetUserName gets the user name from the request.
+// Currently, only HTTP basic authentication is supported
+func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
+ username, _, _ := r.BasicAuth()
+ return username
+}
+
+// CheckPermission checks the user/method/path combination from the request.
+// Returns true (permission granted) or false (permission forbidden)
+func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
+ user := a.GetUserName(r)
+ method := r.Method
+ path := r.URL.Path
+ return a.enforcer.Enforce(user, path, method)
+}
+
+// RequirePermission returns the 403 Forbidden to the client
+func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
+ w.WriteHeader(403)
+ w.Write([]byte("403 Forbidden\n"))
+}
diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf
new file mode 100644
index 00000000..d1b3dbd7
--- /dev/null
+++ b/plugins/authz/authz_model.conf
@@ -0,0 +1,14 @@
+[request_definition]
+r = sub, obj, act
+
+[policy_definition]
+p = sub, obj, act
+
+[role_definition]
+g = _, _
+
+[policy_effect]
+e = some(where (p.eft == allow))
+
+[matchers]
+m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
\ No newline at end of file
diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv
new file mode 100644
index 00000000..c062dd3e
--- /dev/null
+++ b/plugins/authz/authz_policy.csv
@@ -0,0 +1,7 @@
+p, alice, /dataset1/*, GET
+p, alice, /dataset1/resource1, POST
+p, bob, /dataset2/resource1, *
+p, bob, /dataset2/resource2, GET
+p, bob, /dataset2/folder1/*, POST
+p, dataset1_admin, /dataset1/*, *
+g, cathy, dataset1_admin
\ No newline at end of file
diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go
new file mode 100644
index 00000000..49aed84c
--- /dev/null
+++ b/plugins/authz/authz_test.go
@@ -0,0 +1,107 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/plugins/auth"
+ "github.com/casbin/casbin"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) {
+ r, _ := http.NewRequest(method, path, nil)
+ r.SetBasicAuth(user, "123")
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, r)
+
+ if w.Code != code {
+ t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
+ }
+}
+
+func TestBasic(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
+}
+
+func TestPathWildcard(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
+
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
+}
+
+func TestRBAC(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
+ e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
+ testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
+ testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
+ testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
+ testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+
+ // delete all roles on user cathy, so cathy cannot access any resources now.
+ e.DeleteRolesForUser("cathy")
+
+ testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+}
diff --git a/router.go b/router.go
index cf1beceb..a56b1917 100644
--- a/router.go
+++ b/router.go
@@ -27,6 +27,7 @@ import (
"time"
beecontext "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
@@ -117,6 +118,7 @@ type ControllerInfo struct {
runFunction FilterFunc
routerType int
initialize func() ControllerInterface
+ methodParams []*param.MethodParam
}
// ControllerRegister containers registered router rules, controller handlers and filters.
@@ -152,6 +154,10 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
+ p.addWithMethodParams(pattern, c, nil, mappingMethods...)
+}
+
+func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
@@ -203,6 +209,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
return execController
}
+ route.methodParams = methodParams
if len(methods) == 0 {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
@@ -267,7 +274,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm {
- p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
+ p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
}
}
}
@@ -646,11 +653,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
startTime := time.Now()
var (
- runRouter reflect.Type
- findRouter bool
- runMethod string
- routerInfo *ControllerInfo
- isRunnable bool
+ runRouter reflect.Type
+ findRouter bool
+ runMethod string
+ methodParams []*param.MethodParam
+ routerInfo *ControllerInfo
+ isRunnable bool
)
context := p.pool.Get().(*beecontext.Context)
context.Reset(rw, r)
@@ -762,6 +770,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo.handler.ServeHTTP(rw, r)
} else {
runRouter = routerInfo.controllerType
+ methodParams = routerInfo.methodParams
method := r.Method
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost {
method = http.MethodPut
@@ -782,7 +791,17 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// also defined runRouter & runMethod from filter
if !isRunnable {
//Invoke the request handler
- var execController ControllerInterface = routerInfo.initialize()
+ var execController ControllerInterface
+ if routerInfo.initialize != nil {
+ execController = routerInfo.initialize()
+ } else {
+ vc := reflect.New(runRouter)
+ var ok bool
+ execController, ok = vc.Interface().(ControllerInterface)
+ if !ok {
+ panic("controller is not ControllerInterface")
+ }
+ }
//call the controller init function
execController.Init(context, runRouter.Name(), runMethod, execController)
@@ -820,10 +839,15 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Options()
default:
if !execController.HandlerFunc(runMethod) {
- var in []reflect.Value
vc := reflect.ValueOf(execController)
method := vc.MethodByName(runMethod)
- method.Call(in)
+ in := param.ConvertParams(methodParams, method.Type(), context)
+ out := method.Call(in)
+
+ //For backward compatibility we only handle response if we had incoming methodParams
+ if methodParams != nil {
+ p.handleParamResponse(context, execController, out)
+ }
}
}
@@ -903,6 +927,20 @@ Admin:
}
}
+func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
+ //looping in reverse order for the case when both error and value are returned and error sets the response status code
+ for i := len(results) - 1; i >= 0; i-- {
+ result := results[i]
+ if result.Kind() != reflect.Interface || !result.IsNil() {
+ resultValue := result.Interface()
+ context.RenderMethodResult(resultValue)
+ }
+ }
+ if !context.ResponseWriter.Started && context.Output.Status == 0 {
+ context.Output.SetStatus(200)
+ }
+}
+
// FindRouter Find Router info for URL
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
var urlPath = context.Input.URL()
diff --git a/templatefunc.go b/templatefunc.go
index a6f9c961..a104fd24 100644
--- a/templatefunc.go
+++ b/templatefunc.go
@@ -27,9 +27,10 @@ import (
)
const (
- formatTime = "15:04:05"
- formatDate = "2006-01-02"
- formatDateTime = "2006-01-02 15:04:05"
+ formatTime = "15:04:05"
+ formatDate = "2006-01-02"
+ formatDateTime = "2006-01-02 15:04:05"
+ formatDateTimeT = "2006-01-02T15:04:05"
)
// Substr returns the substr from start to length.
@@ -53,21 +54,21 @@ func Substr(s string, start, length int) string {
// HTML2str returns escaping text convert from html.
func HTML2str(html string) string {
- re, _ := regexp.Compile("\\<[\\S\\s]+?\\>")
+ re, _ := regexp.Compile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower)
//remove STYLE
- re, _ = regexp.Compile("\\`)
html = re.ReplaceAllString(html, "")
//remove SCRIPT
- re, _ = regexp.Compile("\\`)
html = re.ReplaceAllString(html, "")
- re, _ = regexp.Compile("\\<[\\S\\s]+?\\>")
+ re, _ = regexp.Compile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n")
- re, _ = regexp.Compile("\\s{2,}")
+ re, _ = regexp.Compile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html)
@@ -360,8 +361,13 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
value = value[:25]
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if len(value) >= 19 {
- value = value[:19]
- t, err = time.ParseInLocation(formatDateTime, value, time.Local)
+ if strings.Contains(value, "T") {
+ value = value[:19]
+ t, err = time.ParseInLocation(formatDateTimeT, value, time.Local)
+ } else {
+ value = value[:19]
+ t, err = time.ParseInLocation(formatDateTime, value, time.Local)
+ }
} else if len(value) >= 10 {
if len(value) > 10 {
value = value[:10]
@@ -373,7 +379,6 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
}
t, err = time.ParseInLocation(formatTime, value, time.Local)
}
-
if err != nil {
return err
}
diff --git a/validation/validation_test.go b/validation/validation_test.go
index 83e881bf..71a8bd17 100644
--- a/validation/validation_test.go
+++ b/validation/validation_test.go
@@ -35,6 +35,12 @@ func TestRequired(t *testing.T) {
if valid.Required("", "string").Ok {
t.Error("\"'\" string should be false")
}
+ if valid.Required(" ", "string").Ok {
+ t.Error("\" \" string should be false") // For #2361
+ }
+ if valid.Required("\n", "string").Ok {
+ t.Error("new line string should be false") // For #2361
+ }
if !valid.Required("astaxie", "string").Ok {
t.Error("string should be true")
}
@@ -175,10 +181,10 @@ func TestAlphaNumeric(t *testing.T) {
func TestMatch(t *testing.T) {
valid := Validation{}
- if valid.Match("suchuangji@gmail", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok {
+ if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false")
}
- if !valid.Match("suchuangji@gmail.com", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok {
+ if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true")
}
}
@@ -186,10 +192,10 @@ func TestMatch(t *testing.T) {
func TestNoMatch(t *testing.T) {
valid := Validation{}
- if valid.NoMatch("123@gmail", regexp.MustCompile("[^\\w\\d]"), "nomatch").Ok {
+ if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false")
}
- if !valid.NoMatch("123gmail", regexp.MustCompile("[^\\w\\d]"), "match").Ok {
+ if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true")
}
}
diff --git a/validation/validators.go b/validation/validators.go
index 01aed443..5d489a55 100644
--- a/validation/validators.go
+++ b/validation/validators.go
@@ -18,6 +18,7 @@ import (
"fmt"
"reflect"
"regexp"
+ "strings"
"time"
"unicode/utf8"
)
@@ -98,7 +99,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
}
if str, ok := obj.(string); ok {
- return len(str) > 0
+ return len(strings.TrimSpace(str)) > 0
}
if _, ok := obj.(bool); ok {
return true
@@ -145,7 +146,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default error message
func (r Required) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Required"])
+ return MessageTmpls["Required"]
}
// GetKey return the r.Key
@@ -364,7 +365,7 @@ func (a Alpha) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message
func (a Alpha) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Alpha"])
+ return MessageTmpls["Alpha"]
}
// GetKey return the m.Key
@@ -397,7 +398,7 @@ func (n Numeric) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message
func (n Numeric) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Numeric"])
+ return MessageTmpls["Numeric"]
}
// GetKey return the n.Key
@@ -430,7 +431,7 @@ func (a AlphaNumeric) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message
func (a AlphaNumeric) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["AlphaNumeric"])
+ return MessageTmpls["AlphaNumeric"]
}
// GetKey return the a.Key
@@ -495,7 +496,7 @@ func (n NoMatch) GetLimitValue() interface{} {
return n.Regexp.String()
}
-var alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]")
+var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`)
// AlphaDash check not Alpha
type AlphaDash struct {
@@ -505,7 +506,7 @@ type AlphaDash struct {
// DefaultMessage return the default AlphaDash error message
func (a AlphaDash) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["AlphaDash"])
+ return MessageTmpls["AlphaDash"]
}
// GetKey return the n.Key
@@ -518,7 +519,7 @@ func (a AlphaDash) GetLimitValue() interface{} {
return nil
}
-var emailPattern = regexp.MustCompile("^[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?$")
+var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`)
// Email check struct
type Email struct {
@@ -528,7 +529,7 @@ type Email struct {
// DefaultMessage return the default Email error message
func (e Email) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Email"])
+ return MessageTmpls["Email"]
}
// GetKey return the n.Key
@@ -541,7 +542,7 @@ func (e Email) GetLimitValue() interface{} {
return nil
}
-var ipPattern = regexp.MustCompile("^((2[0-4]\\d|25[0-5]|[01]?\\d\\d?)\\.){3}(2[0-4]\\d|25[0-5]|[01]?\\d\\d?)$")
+var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`)
// IP check struct
type IP struct {
@@ -551,7 +552,7 @@ type IP struct {
// DefaultMessage return the default IP error message
func (i IP) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["IP"])
+ return MessageTmpls["IP"]
}
// GetKey return the i.Key
@@ -564,7 +565,7 @@ func (i IP) GetLimitValue() interface{} {
return nil
}
-var base64Pattern = regexp.MustCompile("^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$")
+var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`)
// Base64 check struct
type Base64 struct {
@@ -574,7 +575,7 @@ type Base64 struct {
// DefaultMessage return the default Base64 error message
func (b Base64) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Base64"])
+ return MessageTmpls["Base64"]
}
// GetKey return the b.Key
@@ -588,7 +589,7 @@ func (b Base64) GetLimitValue() interface{} {
}
// just for chinese mobile phone number
-var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\\d{8}$")
+var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`)
// Mobile check struct
type Mobile struct {
@@ -598,7 +599,7 @@ type Mobile struct {
// DefaultMessage return the default Mobile error message
func (m Mobile) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Mobile"])
+ return MessageTmpls["Mobile"]
}
// GetKey return the m.Key
@@ -612,7 +613,7 @@ func (m Mobile) GetLimitValue() interface{} {
}
// just for chinese telephone number
-var telPattern = regexp.MustCompile("^(0\\d{2,3}(\\-)?)?\\d{7,8}$")
+var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`)
// Tel check telephone struct
type Tel struct {
@@ -622,7 +623,7 @@ type Tel struct {
// DefaultMessage return the default Tel error message
func (t Tel) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Tel"])
+ return MessageTmpls["Tel"]
}
// GetKey return the t.Key
@@ -649,7 +650,7 @@ func (p Phone) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Phone error message
func (p Phone) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["Phone"])
+ return MessageTmpls["Phone"]
}
// GetKey return the p.Key
@@ -663,7 +664,7 @@ func (p Phone) GetLimitValue() interface{} {
}
// just for chinese zipcode
-var zipCodePattern = regexp.MustCompile("^[1-9]\\d{5}$")
+var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`)
// ZipCode check the zip struct
type ZipCode struct {
@@ -673,7 +674,7 @@ type ZipCode struct {
// DefaultMessage return the default Zip error message
func (z ZipCode) DefaultMessage() string {
- return fmt.Sprint(MessageTmpls["ZipCode"])
+ return MessageTmpls["ZipCode"]
}
// GetKey return the z.Key