diff --git a/orm/models_utils.go b/orm/models_utils.go index 0c3bee5d..31f8fb5a 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -109,7 +109,7 @@ func getTableUnique(val reflect.Value) [][]string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col if col == "" { - column = snakeString(sf.Name) + column = nameStrategyMap[nameStrategy](sf.Name) } switch ft { case RelForeignKey, RelOneToOne: diff --git a/orm/orm.go b/orm/orm.go index b00c974e..bcf6e4be 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -425,7 +425,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { - name = snakeString(table) + name = nameStrategyMap[defaultNameStrategy](table) if mi, ok := modelCache.get(name); ok { qs = newQuerySet(o, mi) } @@ -549,7 +549,7 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName al.DB = db - + detectTZ(al) o := new(orm) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index c8e741ea..c8ef4398 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -358,7 +358,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) var col string if col = tags["column"]; col == "" { - col = snakeString(fe.Name) + col = nameStrategyMap[nameStrategy](fe.Name) } if v, ok := columnsMp[col]; ok { value := reflect.ValueOf(v).Elem().Interface() @@ -509,7 +509,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) var col string if col = tags["column"]; col == "" { - col = snakeString(fe.Name) + col = nameStrategyMap[nameStrategy](fe.Name) } if v, ok := columnsMp[col]; ok { value := reflect.ValueOf(v).Elem().Interface() diff --git a/orm/utils.go b/orm/utils.go index b76bb2e3..78392771 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -23,6 +23,18 @@ import ( "time" ) +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + // StrTo is the target string type StrTo string @@ -198,6 +210,27 @@ func ToInt64(value interface{}) (d int64) { return } +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + // snake string, XxYy to xx_yy , XxYY to xx_y_y func snakeString(s string) string { data := make([]byte, 0, len(s)*2) @@ -216,6 +249,14 @@ func snakeString(s string) string { return strings.ToLower(string(data[:])) } +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + // camel string, xx_yy to XxYy func camelString(s string) string { data := make([]byte, 0, len(s)) diff --git a/orm/utils_test.go b/orm/utils_test.go index 11e76687..7d94cada 100644 --- a/orm/utils_test.go +++ b/orm/utils_test.go @@ -51,3 +51,20 @@ func TestSnakeString(t *testing.T) { } } } + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +}