diff --git a/orm/models_test.go b/orm/models_test.go index 18ee0a59..3fed1a9b 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -3,6 +3,7 @@ package orm import ( "fmt" "os" + "strings" "time" _ "github.com/go-sql-driver/mysql" @@ -10,6 +11,56 @@ import ( _ "github.com/mattn/go-sqlite3" ) +// A true/false field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +func (e *SliceStringField) Clean() error { + return nil +} + type Data struct { Id int Boolean bool @@ -78,6 +129,7 @@ type User struct { Posts []*Post `orm:"reverse(many)" json:"-"` ShouldSkip string `orm:"-"` Nums int + Langs SliceStringField `orm:"size(100)"` } func (u *User) TableIndex() [][]string { diff --git a/orm/orm_test.go b/orm/orm_test.go index 156fe754..917e031f 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -472,6 +472,23 @@ The program—and web server—godoc processes Go source files to extract docume } } +func TestCustomField(t *testing.T) { + user := User{Id: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + _, err = dORM.Update(&user, "Langs") + throwFailNow(t, err) + + user = User{Id: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) +} + func TestExpr(t *testing.T) { user := &User{} qs := dORM.QueryTable(user) @@ -728,7 +745,7 @@ func TestValues(t *testing.T) { var maps []Params qs := dORM.QueryTable("user") - num, err := qs.Values(&maps) + num, err := qs.OrderBy("Id").Values(&maps) throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { @@ -736,7 +753,7 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } - num, err = qs.Values(&maps, "UserName", "Profile__Age") + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { @@ -750,7 +767,7 @@ func TestValuesList(t *testing.T) { var list []ParamsList qs := dORM.QueryTable("user") - num, err := qs.ValuesList(&list) + num, err := qs.OrderBy("Id").ValuesList(&list) throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { @@ -758,7 +775,7 @@ func TestValuesList(t *testing.T) { throwFail(t, AssertIs(list[2][9], nil)) } - num, err = qs.ValuesList(&list, "UserName", "Profile__Age") + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 {