diff --git a/router.go b/router.go index 2f5d2eae..11cb4aa8 100644 --- a/router.go +++ b/router.go @@ -201,9 +201,12 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt numOfFields := elemVal.NumField() for i := 0; i < numOfFields; i++ { - fieldVal := elemVal.Field(i) fieldType := elemType.Field(i) - execElem.FieldByName(fieldType.Name).Set(fieldVal) + + if execElem.FieldByName(fieldType.Name).CanSet() { + fieldVal := elemVal.Field(i) + execElem.FieldByName(fieldType.Name).Set(fieldVal) + } } return execController diff --git a/validation/validation.go b/validation/validation.go index 0d89be30..2781a72e 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -245,7 +245,21 @@ func (v *Validation) ZipCode(obj interface{}, key string) *Result { } func (v *Validation) apply(chk Validator, obj interface{}) *Result { - if chk.IsSatisfied(obj) { + if nil == obj { + if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { + if reflect.ValueOf(obj).IsNil() { + if chk.IsSatisfied(nil) { + return &Result{Ok: true} + } + } else { + if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { + return &Result{Ok: true} + } + } + } else if chk.IsSatisfied(obj) { return &Result{Ok: true} } @@ -357,7 +371,18 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) { hasReuired = true } - if !hasReuired && v.RequiredFirst && len(objV.Field(i).String()) == 0 { + currentField := objV.Field(i).Interface() + if objV.Field(i).Kind() == reflect.Ptr { + if objV.Field(i).IsNil() { + currentField = "" + } else { + currentField = objV.Field(i).Elem().Interface() + } + } + + + chk := Required{""}.IsSatisfied(currentField) + if !hasReuired && v.RequiredFirst && !chk { if _, ok := CanSkipFuncs[vf.Name]; ok { continue } @@ -414,3 +439,9 @@ func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { } return pass, err } + +func (v *Validation) CanSkipAlso(skipFunc string) { + if _, ok := CanSkipFuncs[skipFunc]; !ok { + CanSkipFuncs[skipFunc] = struct{}{} + } +} diff --git a/validation/validation_test.go b/validation/validation_test.go index bf612015..f97105fd 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -442,3 +442,122 @@ func TestSkipValid(t *testing.T) { t.Fatal("validation should be passed") } } + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} +