diff --git a/context/input.go b/context/input.go index df0680e5..639c58c6 100644 --- a/context/input.go +++ b/context/input.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "net/http" + "net/url" "reflect" "strconv" "strings" @@ -261,6 +262,7 @@ func (input *BeegoInput) SetData(key, val interface{}) { input.Data[key] = val } +// parseForm or parseMultiForm based on Content-type func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { // Parse the body depending on the content type. switch input.Header("Content-Type") { @@ -278,3 +280,244 @@ func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { return nil } + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") id ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { + return errors.New("beego: non-pointer passed to Bind: " + key) + } + value = value.Elem() + if !value.CanSet() { + return errors.New("beego: non-settable variable passed to Bind: " + key) + } + rv := input.bind(key, value.Type()) + if !rv.IsValid() { + return errors.New("beego: reflect value is empty") + } + value.Set(rv) + return nil +} + +func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(reflect.TypeOf(0)) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindFloat(val, typ) + case reflect.String: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindString(val, typ) + case reflect.Bool: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&input.Request.Form, key, typ) + case reflect.Struct: + rv = input.bindStruct(&input.Request.Form, key, typ) + case reflect.Ptr: + rv = input.bindPoint(key, typ) + case reflect.Map: + rv = input.bindMap(&input.Request.Form, key, typ) + } + return rv +} + +func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(reflect.TypeOf(0)) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + rv = input.bindFloat(val, typ) + case reflect.String: + rv = input.bindString(val, typ) + case reflect.Bool: + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&url.Values{"": {val}}, "", typ) + case reflect.Struct: + rv = input.bindStruct(&url.Values{"": {val}}, "", typ) + case reflect.Ptr: + rv = input.bindPoint(val, typ) + case reflect.Map: + rv = input.bindMap(&url.Values{"": {val}}, "", typ) + } + return rv +} + +func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { + intValue, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetInt(intValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { + uintValue, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetUint(uintValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { + floatValue, err := strconv.ParseFloat(val, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetFloat(floatValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { + return reflect.ValueOf(val) +} + +func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "true", "on", "1": + return reflect.ValueOf(true) + } + return reflect.ValueOf(false) +} + +type sliceValue struct { + index int // Index extracted from brackets. If -1, no index was provided. + value reflect.Value // the bound value for this slice element. +} + +func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { + maxIndex := -1 + numNoIndex := 0 + sliceValues := []sliceValue{} + for reqKey, vals := range *params { + if !strings.HasPrefix(reqKey, key+"[") { + continue + } + // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) + index := -1 + leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) + if rightBracket > leftBracket+1 { + index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) + } + subKeyIndex := rightBracket + 1 + + // Handle the indexed case. + if index > -1 { + if index > maxIndex { + maxIndex = index + } + sliceValues = append(sliceValues, sliceValue{ + index: index, + value: input.bind(reqKey[:subKeyIndex], typ.Elem()), + }) + continue + } + + // It's an un-indexed element. (e.g. element[]) + numNoIndex += len(vals) + for _, val := range vals { + // Unindexed values can only be direct-bound. + sliceValues = append(sliceValues, sliceValue{ + index: -1, + value: input.bindValue(val, typ.Elem()), + }) + } + } + resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) + for _, sv := range sliceValues { + if sv.index != -1 { + resultArray.Index(sv.index).Set(sv.value) + } else { + resultArray = reflect.Append(resultArray, sv.value) + } + } + return resultArray +} + +func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { + result := reflect.New(typ).Elem() + fieldValues := make(map[string]reflect.Value) + for reqKey, val := range *params { + if !strings.HasPrefix(reqKey, key+".") { + continue + } + + fieldName := reqKey[len(key)+1:] + + if _, ok := fieldValues[fieldName]; !ok { + // Time to bind this field. Get it and make sure we can set it. + fieldValue := result.FieldByName(fieldName) + if !fieldValue.IsValid() { + continue + } + if !fieldValue.CanSet() { + continue + } + boundVal := input.bindValue(val[0], fieldValue.Type()) + fieldValue.Set(boundVal) + fieldValues[fieldName] = boundVal + } + } + + return result +} + +func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { + return input.bind(key, typ.Elem()).Addr() +} + +func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { + var ( + result = reflect.MakeMap(typ) + keyType = typ.Key() + valueType = typ.Elem() + ) + for paramName, values := range *params { + if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { + continue + } + + key := paramName[len(key)+1 : len(paramName)-1] + result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) + } + return result +} diff --git a/context/input_test.go b/context/input_test.go new file mode 100644 index 00000000..26e95982 --- /dev/null +++ b/context/input_test.go @@ -0,0 +1,58 @@ +package context + +import ( + "fmt" + "net/http" + "testing" +) + +func TestParse(t *testing.T) { + r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) + beegoInput := NewInput(r) + beegoInput.ParseFormOrMulitForm(1 << 20) + + var id int + err := beegoInput.Bind(&id, "id") + if id != 123 || err != nil { + t.Fatal("id should has int value") + } + fmt.Println(id) + + var isok bool + err = beegoInput.Bind(&isok, "isok") + if !isok || err != nil { + t.Fatal("isok should be true") + } + fmt.Println(isok) + + var float float64 + err = beegoInput.Bind(&float, "ft") + if float != 1.2 || err != nil { + t.Fatal("float should be equal to 1.2") + } + fmt.Println(float) + + ol := make([]int, 0, 2) + err = beegoInput.Bind(&ol, "ol") + if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 { + t.Fatal("ol should has two elements") + } + fmt.Println(ol) + + ul := make([]string, 0, 2) + err = beegoInput.Bind(&ul, "ul") + if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" { + t.Fatal("ul should has two elements") + } + fmt.Println(ul) + + type User struct { + Name string + } + user := User{} + err = beegoInput.Bind(&user, "user") + if err != nil || user.Name != "astaxie" { + t.Fatal("user should has name") + } + fmt.Println(user) +}