From ffe1d5212009bae069acfe3a2d02954c7acb1756 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 15 Jul 2020 10:04:22 +0800 Subject: [PATCH] Move orm to pkg/orm --- pkg/orm/README.md | 159 +++ pkg/orm/cmd.go | 283 +++++ pkg/orm/cmd_utils.go | 320 +++++ pkg/orm/db.go | 1902 +++++++++++++++++++++++++++++ pkg/orm/db_alias.go | 466 +++++++ pkg/orm/db_mysql.go | 183 +++ pkg/orm/db_oracle.go | 137 +++ pkg/orm/db_postgres.go | 189 +++ pkg/orm/db_sqlite.go | 161 +++ pkg/orm/db_tables.go | 482 ++++++++ pkg/orm/db_tidb.go | 63 + pkg/orm/db_utils.go | 177 +++ pkg/orm/models.go | 99 ++ pkg/orm/models_boot.go | 347 ++++++ pkg/orm/models_fields.go | 783 ++++++++++++ pkg/orm/models_info_f.go | 473 ++++++++ pkg/orm/models_info_m.go | 148 +++ pkg/orm/models_test.go | 497 ++++++++ pkg/orm/models_utils.go | 227 ++++ pkg/orm/orm.go | 579 +++++++++ pkg/orm/orm_conds.go | 153 +++ pkg/orm/orm_log.go | 222 ++++ pkg/orm/orm_object.go | 87 ++ pkg/orm/orm_querym2m.go | 140 +++ pkg/orm/orm_queryset.go | 300 +++++ pkg/orm/orm_raw.go | 867 +++++++++++++ pkg/orm/orm_test.go | 2494 ++++++++++++++++++++++++++++++++++++++ pkg/orm/qb.go | 62 + pkg/orm/qb_mysql.go | 185 +++ pkg/orm/qb_tidb.go | 182 +++ pkg/orm/types.go | 473 ++++++++ pkg/orm/utils.go | 319 +++++ pkg/orm/utils_test.go | 70 ++ 33 files changed, 13229 insertions(+) create mode 100644 pkg/orm/README.md create mode 100644 pkg/orm/cmd.go create mode 100644 pkg/orm/cmd_utils.go create mode 100644 pkg/orm/db.go create mode 100644 pkg/orm/db_alias.go create mode 100644 pkg/orm/db_mysql.go create mode 100644 pkg/orm/db_oracle.go create mode 100644 pkg/orm/db_postgres.go create mode 100644 pkg/orm/db_sqlite.go create mode 100644 pkg/orm/db_tables.go create mode 100644 pkg/orm/db_tidb.go create mode 100644 pkg/orm/db_utils.go create mode 100644 pkg/orm/models.go create mode 100644 pkg/orm/models_boot.go create mode 100644 pkg/orm/models_fields.go create mode 100644 pkg/orm/models_info_f.go create mode 100644 pkg/orm/models_info_m.go create mode 100644 pkg/orm/models_test.go create mode 100644 pkg/orm/models_utils.go create mode 100644 pkg/orm/orm.go create mode 100644 pkg/orm/orm_conds.go create mode 100644 pkg/orm/orm_log.go create mode 100644 pkg/orm/orm_object.go create mode 100644 pkg/orm/orm_querym2m.go create mode 100644 pkg/orm/orm_queryset.go create mode 100644 pkg/orm/orm_raw.go create mode 100644 pkg/orm/orm_test.go create mode 100644 pkg/orm/qb.go create mode 100644 pkg/orm/qb_mysql.go create mode 100644 pkg/orm/qb_tidb.go create mode 100644 pkg/orm/types.go create mode 100644 pkg/orm/utils.go create mode 100644 pkg/orm/utils_test.go diff --git a/pkg/orm/README.md b/pkg/orm/README.md new file mode 100644 index 00000000..6e808d2a --- /dev/null +++ b/pkg/orm/README.md @@ -0,0 +1,159 @@ +# beego orm + +[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) + +A powerful orm framework for go. + +It is heavily influenced by Django ORM, SQLAlchemy. + +**Support Database:** + +* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) +* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +Passed all test, but need more feedback. + +**Features:** + +* full go type support +* easy for usage, simple CRUD operation +* auto join with relation table +* cross DataBase compatible query +* Raw SQL query / mapper without orm model +* full test keep stable and strong + +more features please read the docs + +**Install:** + + go get github.com/astaxie/beego/orm + +## Changelog + +* 2013-08-19: support table auto create +* 2013-08-13: update test for database types +* 2013-08-13: go type support, such as int8, uint8, byte, rune +* 2013-08-13: date / datetime timezone support very well + +## Quick Start + +#### Simple Usage + +```go +package main + +import ( + "fmt" + "github.com/astaxie/beego/orm" + _ "github.com/go-sql-driver/mysql" // import your used driver +) + +// Model Struct +type User struct { + Id int `orm:"auto"` + Name string `orm:"size(100)"` +} + +func init() { + // register model + orm.RegisterModel(new(User)) + + // set default database + orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) + + // create table + orm.RunSyncdb("default", false, true) +} + +func main() { + o := orm.NewOrm() + + user := User{Name: "slene"} + + // insert + id, err := o.Insert(&user) + + // update + user.Name = "astaxie" + num, err := o.Update(&user) + + // read one + u := User{Id: user.Id} + err = o.Read(&u) + + // delete + num, err = o.Delete(&u) +} +``` + +#### Next with relation + +```go +type Post struct { + Id int `orm:"auto"` + Title string `orm:"size(100)"` + User *User `orm:"rel(fk)"` +} + +var posts []*Post +qs := o.QueryTable("post") +num, err := qs.Filter("User__Name", "slene").All(&posts) +``` + +#### Use Raw sql + +If you don't like ORM,use Raw SQL to query / mapping without ORM setting + +```go +var maps []Params +num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) +if num > 0 { + fmt.Println(maps[0]["id"]) +} +``` + +#### Transaction + +```go +o.Begin() +... +user := User{Name: "slene"} +id, err := o.Insert(&user) +if err == nil { + o.Commit() +} else { + o.Rollback() +} + +``` + +#### Debug Log Queries + +In development env, you can simple use + +```go +func main() { + orm.Debug = true +... +``` + +enable log queries. + +output include all queries, such as exec / prepare / transaction. + +like this: + +```go +[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` +... +``` + +note: not recommend use this in product env. + +## Docs + +more details and examples in docs and test + +[documents](http://beego.me/docs/mvc/model/overview.md) + diff --git a/pkg/orm/cmd.go b/pkg/orm/cmd.go new file mode 100644 index 00000000..0ff4dc40 --- /dev/null +++ b/pkg/orm/cmd.go @@ -0,0 +1,283 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "flag" + "fmt" + "os" + "strings" +) + +type commander interface { + Parse([]string) + Run() error +} + +var ( + commands = make(map[string]commander) +) + +// print help. +func printHelp(errs ...string) { + content := `orm command usage: + + syncdb - auto create tables + sqlall - print sql of create tables + help - print this help +` + + if len(errs) > 0 { + fmt.Println(errs[0]) + } + fmt.Println(content) + os.Exit(2) +} + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + if len(os.Args) < 2 || os.Args[1] != "orm" { + return + } + + BootStrap() + + args := argString(os.Args[2:]) + name := args.Get(0) + + if name == "help" { + printHelp() + } + + if cmd, ok := commands[name]; ok { + cmd.Parse(os.Args[3:]) + cmd.Run() + os.Exit(0) + } else { + if name == "" { + printHelp() + } else { + printHelp(fmt.Sprintf("unknown command %s", name)) + } + } +} + +// sync database struct command interface. +type commandSyncDb struct { + al *alias + force bool + verbose bool + noInfo bool + rtOnError bool +} + +// parse orm command line arguments. +func (d *commandSyncDb) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.BoolVar(&d.force, "force", false, "drop tables before create") + flagSet.BoolVar(&d.verbose, "v", false, "verbose info") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSyncDb) Run() error { + var drops []string + if d.force { + drops = getDbDropSQL(d.al) + } + + db := d.al.DB + + if d.force { + for i, mi := range modelCache.allOrdered() { + query := drops[i] + if !d.noInfo { + fmt.Printf("drop table `%s`\n", mi.table) + } + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + sqls, indexes := getDbCreateSQL(d.al) + + tables, err := d.al.DbBaser.GetTables(db) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for i, mi := range modelCache.allOrdered() { + if tables[mi.table] { + if !d.noInfo { + fmt.Printf("table `%s` already exists, skip\n", mi.table) + } + + var fields []*fieldInfo + columns, err := d.al.DbBaser.GetColumns(db, mi.table) + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + + for _, fi := range mi.fields.fieldsDB { + if _, ok := columns[fi.column]; !ok { + fields = append(fields, fi) + } + } + + for _, fi := range fields { + query := getColumnAddQuery(d.al, fi) + + if !d.noInfo { + fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + } + + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + + for _, idx := range indexes[mi.table] { + if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.noInfo { + fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) + } + + query := idx.SQL + _, err := db.Exec(query) + if d.verbose { + fmt.Printf(" %s\n", query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + } + + continue + } + + if !d.noInfo { + fmt.Printf("create table `%s` \n", mi.table) + } + + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + + for _, query := range queries { + _, err := db.Exec(query) + if d.verbose { + query = " " + strings.Join(strings.Split(query, "\n"), "\n ") + fmt.Println(query) + } + if err != nil { + if d.rtOnError { + return err + } + fmt.Printf(" %s\n", err.Error()) + } + } + if d.verbose { + fmt.Println("") + } + } + + return nil +} + +// database creation commander interface implement. +type commandSQLAll struct { + al *alias +} + +// parse orm command line arguments. +func (d *commandSQLAll) Parse(args []string) { + var name string + + flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) + flagSet.StringVar(&name, "db", "default", "DataBase alias name") + flagSet.Parse(args) + + d.al = getDbAlias(name) +} + +// run orm line command. +func (d *commandSQLAll) Run() error { + sqls, indexes := getDbCreateSQL(d.al) + var all []string + for i, mi := range modelCache.allOrdered() { + queries := []string{sqls[i]} + for _, idx := range indexes[mi.table] { + queries = append(queries, idx.SQL) + } + sql := strings.Join(queries, "\n") + all = append(all, sql) + } + fmt.Println(strings.Join(all, "\n\n")) + + return nil +} + +func init() { + commands["syncdb"] = new(commandSyncDb) + commands["sqlall"] = new(commandSQLAll) +} + +// RunSyncdb run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. +func RunSyncdb(name string, force bool, verbose bool) error { + BootStrap() + + al := getDbAlias(name) + cmd := new(commandSyncDb) + cmd.al = al + cmd.force = force + cmd.noInfo = !verbose + cmd.verbose = verbose + cmd.rtOnError = true + return cmd.Run() +} diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go new file mode 100644 index 00000000..61f17346 --- /dev/null +++ b/pkg/orm/cmd_utils.go @@ -0,0 +1,320 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "strings" +) + +type dbIndex struct { + Table string + Name string + SQL string +} + +// create database drop sql. +func getDbDropSQL(al *alias) (sqls []string) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return sqls +} + +// get database column type string. +func getColumnTyp(al *alias, fi *fieldInfo) (col string) { + T := al.DbBaser.DbTypes() + fieldType := fi.fieldType + fieldSize := fi.size + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeVarCharField: + if al.Driver == DRPostgres && fi.toText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } + case TypeCharField: + col = fmt.Sprintf(T["string-char"], fieldSize) + case TypeTextField: + col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + col = T["time.Time"] + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if al.Driver == DRSqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if !strings.Contains(s, "%d") { + col = s + } else { + col = fmt.Sprintf(s, fi.digits, fi.decimals) + } + case TypeJSONField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["jsonb"] + case RelForeignKey, RelOneToOne: + fieldType = fi.relModelInfo.fields.pk.fieldType + fieldSize = fi.relModelInfo.fields.pk.size + goto checkColumn + } + + return +} + +// create alter sql string. +func getColumnAddQuery(al *alias, fi *fieldInfo) string { + Q := al.DbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if !fi.null { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.mi.table, Q, + Q, fi.column, Q, + typ, getColumnDefault(fi), + ) +} + +// create database creation string. +func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { + if len(modelCache.cache) == 0 { + fmt.Println("no Model found, need register your model") + os.Exit(2) + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + sqls = append(sqls, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + +// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands +func getColumnDefault(fi *fieldInfo) string { + var ( + v, t, d string + ) + + // Skip default attribute if field is in relations + if fi.rel || fi.reverse { + return v + } + + t = " DEFAULT '%s' " + + // These defaults will be useful if there no config value orm:"default" and NOT NULL is on + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, + TypeDecimalField: + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" + } + + if fi.colDefault { + if !fi.initial.Exist() { + v = fmt.Sprintf(t, "") + } else { + v = fmt.Sprintf(t, fi.initial.String()) + } + } else { + if !fi.null { + v = fmt.Sprintf(t, d) + } + } + + return v +} diff --git a/pkg/orm/db.go b/pkg/orm/db.go new file mode 100644 index 00000000..9a1827e8 --- /dev/null +++ b/pkg/orm/db.go @@ -0,0 +1,1902 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = errors.New("missed pk value") +) + +var ( + operators = map[string]bool{ + "exact": true, + "iexact": true, + "contains": true, + "icontains": true, + // "regex": true, + // "iregex": true, + "gt": true, + "gte": true, + "lt": true, + "lte": true, + "eq": true, + "nq": true, + "ne": true, + "startswith": true, + "endswith": true, + "istartswith": true, + "iendswith": true, + "in": true, + "between": true, + // "year": true, + // "month": true, + // "day": true, + // "week_day": true, + "isnull": true, + // "search": true, + } +) + +// an instance of dbBaser interface/ +type dbBase struct { + ins dbBaser +} + +// check dbBase implements dbBaser interface. +var _ dbBaser = new(dbBase) + +// get struct columns values as interface slice. +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { + if names == nil { + ns := make([]string, 0, len(cols)) + names = &ns + } + values = make([]interface{}, 0, len(cols)) + + for _, column := range cols { + var fi *fieldInfo + if fi, _ = mi.fields.GetByAny(column); fi != nil { + column = fi.column + } else { + panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + } + if !fi.dbcol || fi.auto && skipAuto { + continue + } + value, err := d.collectFieldValue(mi, fi, ind, insert, tz) + if err != nil { + return nil, nil, err + } + + // ignore empty value auto field + if insert && fi.auto { + if fi.fieldType&IsPositiveIntegerField > 0 { + if vu, ok := value.(uint64); !ok || vu == 0 { + continue + } + } else { + if vu, ok := value.(int64); !ok || vu == 0 { + continue + } + } + autoFields = append(autoFields, fi.column) + } + + *names, values = append(*names, column), append(values, value) + } + + return +} + +// get one field value in struct column as interface. +func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { + var value interface{} + if fi.pk { + _, value, _ = getExistPk(mi, ind) + } else { + field := ind.FieldByIndex(fi.fieldIndex) + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + value = f.RawValue() + } else { + switch fi.fieldType { + case TypeBooleanField: + if nb, ok := field.Interface().(sql.NullBool); ok { + value = nil + if nb.Valid { + value = nb.Bool + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Bool() + } + } else { + value = field.Bool() + } + case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: + if ns, ok := field.Interface().(sql.NullString); ok { + value = nil + if ns.Valid { + value = ns.String + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().String() + } + } else { + value = field.String() + } + case TypeFloatField, TypeDecimalField: + if nf, ok := field.Interface().(sql.NullFloat64); ok { + value = nil + if nf.Valid { + value = nf.Float64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Float() + } + } else { + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } + } + case TypeTimeField, TypeDateField, TypeDateTimeField: + value = field.Interface() + if t, ok := value.(time.Time); ok { + d.ins.TimeToDB(&t, tz) + if t.IsZero() { + value = nil + } else { + value = t + } + } + default: + switch { + case fi.fieldType&IsPositiveIntegerField > 0: + if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Uint() + } + } else { + value = field.Uint() + } + case fi.fieldType&IsIntegerField > 0: + if ni, ok := field.Interface().(sql.NullInt64); ok { + value = nil + if ni.Valid { + value = ni.Int64 + } + } else if field.Kind() == reflect.Ptr { + if field.IsNil() { + value = nil + } else { + value = field.Elem().Int() + } + } else { + value = field.Int() + } + case fi.fieldType&IsRelField > 0: + if field.IsNil() { + value = nil + } else { + if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { + value = vu + } else { + value = nil + } + } + if !fi.null && value == nil { + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) + } + } + } + } + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField: + if fi.autoNow || fi.autoNowAdd && insert { + if insert { + if t, ok := value.(time.Time); ok && !t.IsZero() { + break + } + } + tnow := time.Now() + d.ins.TimeToDB(&tnow, tz) + value = tnow + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + f.SetRaw(tnow.In(DefaultTimeLoc)) + } else if field.Kind() == reflect.Ptr { + v := tnow.In(DefaultTimeLoc) + field.Set(reflect.ValueOf(&v)) + } else { + field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) + } + } + case TypeJSONField, TypeJsonbField: + if s, ok := value.(string); (ok && len(s) == 0) || value == nil { + if fi.colDefault && fi.initial.Exist() { + value = fi.initial.String() + } else { + value = nil + } + } + } + } + return value, nil +} + +// create insert sql preparation statement object. +func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { + Q := d.ins.TableQuote() + + dbcols := make([]string, 0, len(mi.fields.dbcols)) + marks := make([]string, 0, len(mi.fields.dbcols)) + for _, fi := range mi.fields.fieldsDB { + if !fi.auto { + dbcols = append(dbcols, fi.column) + marks = append(marks, "?") + } + } + qmarks := strings.Join(marks, ", ") + sep := fmt.Sprintf("%s, %s", Q, Q) + columns := strings.Join(dbcols, sep) + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + d.ins.HasReturningID(mi, &query) + + stmt, err := q.Prepare(query) + return stmt, query, err +} + +// insert struct with prepared statement and given struct reflect value. +func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + if err != nil { + return 0, err + } + + if d.ins.HasReturningID(mi, nil) { + row := stmt.QueryRow(values...) + var id int64 + err := row.Scan(&id) + return id, err + } + res, err := stmt.Exec(values...) + if err == nil { + return res.LastInsertId() + } + return 0, err +} + +// query sql ,read records and persist in dbBaser. +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + var whereCols []string + var args []interface{} + + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s, %s", Q, Q) + sels := strings.Join(mi.fields.dbcols, sep) + colsNum := len(mi.fields.dbcols) + + sep = fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + forUpdate := "" + if isForUpdate { + forUpdate = "FOR UPDATE" + } + + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + d.ins.ReplaceMarks(&query) + + row := q.QueryRow(query, args...) + if err := row.Scan(refs...); err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + ind.Set(mind) + return nil +} + +// execute insert sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names := make([]string, 0, len(mi.fields.dbcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return 0, err + } + + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err +} + +// multi-insert sql with given slice struct reflect.Value. +func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { + var ( + cnt int64 + nums int + values []interface{} + names []string + ) + + // typ := reflect.Indirect(mi.addrField).Type() + + length, autoFields := sind.Len(), make([]string, 0, 1) + + for i := 1; i <= length; i++ { + + ind := reflect.Indirect(sind.Index(i - 1)) + + // Is this needed ? + // if !ind.Type().AssignableTo(typ) { + // return cnt, ErrArgs + // } + + if i == 1 { + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + if err != nil { + return cnt, err + } + values = make([]interface{}, bulk*len(vus)) + nums += copy(values, vus) + } else { + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + if err != nil { + return cnt, err + } + + if len(vus) != len(names) { + return cnt, ErrArgs + } + + nums += copy(values[nums:], vus) + } + + if i > 1 && i%bulk == 0 || length == i { + num, err := d.InsertValue(q, mi, true, names, values[:nums]) + if err != nil { + return cnt, err + } + cnt += num + nums = 0 + } + } + + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = "?" + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti && multi > 1 { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + args0 := "" + iouStr := "" + argsMap := map[string]string{} + switch a.Driver { + case DRMySQL: + iouStr = "ON DUPLICATE KEY UPDATE" + case DRPostgres: + if len(args) == 0 { + return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) + } + args0 = strings.ToLower(args[0]) + iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) + default: + return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) + } + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + var conflitValue interface{} + for i, v := range names { + // identifier in database may not be case-sensitive, so quote it + v = fmt.Sprintf("%s%s%s", Q, v, Q) + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if v == args0 { + conflitValue = values[i] + } + if valueStr != "" { + switch a.Driver { + case DRMySQL: + updates[i] = v + "=" + valueStr + case DRPostgres: + if conflitValue != nil { + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) + updateValues = append(updateValues, conflitValue) + } else { + return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) + } + } + } else { + updates[i] = v + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + if err != nil && err.Error() == `pq: syntax error at or near "ON"` { + err = fmt.Errorf("postgres version must 9.5 or higher") + } + return id, err +} + +// execute update sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + pkName, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + + var setNames []string + + // if specify cols length is zero, then commit all columns. + if len(cols) == 0 { + cols = mi.fields.dbcols + setNames = make([]string, 0, len(mi.fields.dbcols)-1) + } else { + setNames = make([]string, 0, len(cols)) + } + + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + if err != nil { + return 0, err + } + + var findAutoNowAdd, findAutoNow bool + var index int + for i, col := range setNames { + if mi.fields.GetByColumn(col).autoNowAdd { + index = i + findAutoNowAdd = true + } + if mi.fields.GetByColumn(col).autoNow { + findAutoNow = true + } + } + if findAutoNowAdd { + setNames = append(setNames[0:index], setNames[index+1:]...) + setValues = append(setValues[0:index], setValues[index+1:]...) + } + + if !findAutoNow { + for col, info := range mi.fields.columns { + if info.autoNow { + setNames = append(setNames, col) + setValues = append(setValues, time.Now()) + } + } + } + + setValues = append(setValues, pkValue) + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ?, %s", Q, Q) + setColumns := strings.Join(setNames, sep) + + query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + + d.ins.ReplaceMarks(&query) + + res, err := q.Exec(query, setValues...) + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { + var whereCols []string + var args []interface{} + // if specify cols length > 0, then use it for where condition. + if len(cols) > 0 { + var err error + whereCols = make([]string, 0, len(cols)) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + if err != nil { + return 0, err + } + } else { + // default use pk value as where condtion. + pkColumn, pkValue, ok := getExistPk(mi, ind) + if !ok { + return 0, ErrMissPK + } + whereCols = []string{pkColumn} + args = append(args, pkValue) + } + + Q := d.ins.TableQuote() + + sep := fmt.Sprintf("%s = ? AND %s", Q, Q) + wheres := strings.Join(whereCols, sep) + + query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) + + d.ins.ReplaceMarks(&query) + res, err := q.Exec(query, args...) + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) + } + } + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, err + } + return 0, err +} + +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. +func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { + columns := make([]string, 0, len(params)) + values := make([]interface{}, 0, len(params)) + for col, val := range params { + if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { + panic(fmt.Errorf("wrong field/column name `%s`", col)) + } else { + columns = append(columns, fi.column) + values = append(values, val) + } + } + + if len(columns) == 0 { + panic(fmt.Errorf("update params cannot empty")) + } + + tables := newDbTables(mi, d.ins) + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + where, args := tables.getCondSQL(cond, false, tz) + + values = append(values, args...) + + join := tables.getJoinSQL() + + var query, T string + + Q := d.ins.TableQuote() + + if d.ins.SupportUpdateJoin() { + T = "T0." + } + + cols := make([]string, 0, len(columns)) + + for i, v := range columns { + col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) + if c, ok := values[i].(colValue); ok { + switch c.opt { + case ColAdd: + cols = append(cols, col+" = "+col+" + ?") + case ColMinus: + cols = append(cols, col+" = "+col+" - ?") + case ColMultiply: + cols = append(cols, col+" = "+col+" * ?") + case ColExcept: + cols = append(cols, col+" = "+col+" / ?") + case ColBitAnd: + cols = append(cols, col+" = "+col+" & ?") + case ColBitRShift: + cols = append(cols, col+" = "+col+" >> ?") + case ColBitLShift: + cols = append(cols, col+" = "+col+" << ?") + case ColBitXOR: + cols = append(cols, col+" = "+col+" ^ ?") + case ColBitOr: + cols = append(cols, col+" = "+col+" | ?") + } + values[i] = c.value + } else { + cols = append(cols, col+" = ?") + } + } + + sets := strings.Join(cols, ", ") + " " + + if d.ins.SupportUpdateJoin() { + query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) + } else { + supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) + query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) + } + + d.ins.ReplaceMarks(&query) + var err error + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, values...) + } else { + res, err = q.Exec(query, values...) + } + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. +func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { + for _, fi := range mi.fields.fieldsReverse { + fi = fi.reverseFieldInfo + switch fi.onDelete { + case odCascade: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + if err != nil { + return err + } + case odSetDefault, odSetNULL: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) + params := Params{fi.column: nil} + if fi.onDelete == odSetDefault { + params[fi.column] = fi.initial.String() + } + _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + if err != nil { + return err + } + case odDoNothing: + } + } + return nil +} + +// delete table-related records. +func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { + tables := newDbTables(mi, d.ins) + tables.skipEnd = true + + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + if cond == nil || cond.IsEmpty() { + panic(fmt.Errorf("delete operation cannot execute without condition")) + } + + Q := d.ins.TableQuote() + + where, args := tables.getCondSQL(cond, false, tz) + join := tables.getJoinSQL() + + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + r, err := q.Query(query, args...) + if err != nil { + return 0, err + } + rs = r + defer rs.Close() + + var ref interface{} + args = make([]interface{}, 0) + cnt := 0 + for rs.Next() { + if err := rs.Scan(&ref); err != nil { + return 0, err + } + pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + if err != nil { + return 0, err + } + args = append(args, pkValue) + cnt++ + } + + if cnt == 0 { + return 0, nil + } + + marks := make([]string, len(args)) + for i := range marks { + marks[i] = "?" + } + sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) + + d.ins.ReplaceMarks(&query) + var res sql.Result + if qs != nil && qs.forContext { + res, err = q.ExecContext(qs.ctx, query, args...) + } else { + res, err = q.Exec(query, args...) + } + if err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + if num > 0 { + err := d.deleteRels(q, mi, args, tz) + if err != nil { + return num, err + } + } + return num, nil + } + return 0, err +} + +// read related records. +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { + + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + errTyp := true + one := true + isPtr := true + + if val.Kind() == reflect.Ptr { + fn := "" + if ind.Kind() == reflect.Slice { + one = false + typ := ind.Type().Elem() + switch typ.Kind() { + case reflect.Ptr: + fn = getFullName(typ.Elem()) + case reflect.Struct: + isPtr = false + fn = getFullName(typ) + } + } else { + fn = getFullName(ind.Type()) + } + errTyp = fn != mi.fullName + } + + if errTyp { + if one { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) + } else { + panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) + } + } + + rlimit := qs.limit + offset := qs.offset + + Q := d.ins.TableQuote() + + var tCols []string + if len(cols) > 0 { + hasRel := len(qs.related) > 0 || qs.relDepth > 0 + tCols = make([]string, 0, len(cols)) + var maps map[string]bool + if hasRel { + maps = make(map[string]bool) + } + for _, col := range cols { + if fi, ok := mi.fields.GetByAny(col); ok { + tCols = append(tCols, fi.column) + if hasRel { + maps[fi.column] = true + } + } else { + return 0, fmt.Errorf("wrong field/column name `%s`", col) + } + } + if hasRel { + for _, fi := range mi.fields.fieldsDB { + if fi.fieldType&IsRelField > 0 { + if !maps[fi.column] { + tCols = append(tCols, fi.column) + } + } + } + } + } else { + tCols = mi.fields.dbcols + } + + colsNum := len(tCols) + sep := fmt.Sprintf("%s, T0.%s", Q, Q) + sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) + + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, offset, rlimit) + join := tables.getJoinSQL() + + for _, tbl := range tables.tables { + if tbl.sel { + colsNum += len(tbl.mi.fields.dbcols) + sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) + sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + } + } + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + if qs.forupdate { + query += " FOR UPDATE" + } + + d.ins.ReplaceMarks(&query) + + var rs *sql.Rows + var err error + if qs != nil && qs.forContext { + rs, err = q.QueryContext(qs.ctx, query, args...) + if err != nil { + return 0, err + } + } else { + rs, err = q.Query(query, args...) + if err != nil { + return 0, err + } + } + + refs := make([]interface{}, colsNum) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + slice := ind + + var cnt int64 + for rs.Next() { + if one && cnt == 0 || !one { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + + cacheV := make(map[string]*reflect.Value) + cacheM := make(map[string]*modelInfo) + trefs := refs + + d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) + trefs = refs[len(tCols):] + + for _, tbl := range tables.tables { + // loop selected tables + if tbl.sel { + last := mind + names := "" + mmi := mi + // loop cascade models + for _, name := range tbl.names { + names += name + if val, ok := cacheV[names]; ok { + last = *val + mmi = cacheM[names] + } else { + fi := mmi.fields.GetByName(name) + lastm := mmi + mmi = fi.relModelInfo + field := last + if last.Kind() != reflect.Invalid { + field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) + if field.IsValid() { + d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) + for _, fi := range mmi.fields.fieldsReverse { + if fi.inModel && fi.reverseFieldInfo.mi == lastm { + if fi.reverseFieldInfo != nil { + f := field.FieldByIndex(fi.fieldIndex) + if f.Kind() == reflect.Ptr { + f.Set(last.Addr()) + } + } + } + } + last = field + } + } + cacheV[names] = &field + cacheM[names] = mmi + } + } + trefs = trefs[len(mmi.fields.dbcols):] + } + } + + if one { + ind.Set(mind) + } else { + if cnt == 0 { + // you can use a empty & caped container list + // orm will not replace it + if ind.Len() != 0 { + // if container is not empty + // create a new one + slice = reflect.New(ind.Type()).Elem() + } + } + + if isPtr { + slice = reflect.Append(slice, mind.Addr()) + } else { + slice = reflect.Append(slice, mind) + } + } + } + cnt++ + } + + if !one { + if cnt > 0 { + ind.Set(slice) + } else { + // when a result is empty and container is nil + // to set a empty container + if ind.IsNil() { + ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) + } + } + } + + return cnt, nil +} + +// excute count sql and return count result int64. +func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + tables.getOrderSQL(qs.orders) + join := tables.getJoinSQL() + + Q := d.ins.TableQuote() + + query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + + if groupBy != "" { + query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) + } + + d.ins.ReplaceMarks(&query) + + var row *sql.Row + if qs != nil && qs.forContext { + row = q.QueryRowContext(qs.ctx, query, args...) + } else { + row = q.QueryRow(query, args...) + } + err = row.Scan(&cnt) + return +} + +// generate sql with replacing operator string placeholders and replaced values. +func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { + var sql string + params := getFlatParams(fi, args, tz) + + if len(params) == 0 { + panic(fmt.Errorf("operator `%s` need at least one args", operator)) + } + arg := params[0] + + switch operator { + case "in": + marks := make([]string, len(params)) + for i := range marks { + marks[i] = "?" + } + sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + case "between": + if len(params) != 2 { + panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) + } + sql = "BETWEEN ? AND ?" + default: + if len(params) > 1 { + panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) + } + sql = d.ins.OperatorSQL(operator) + switch operator { + case "exact": + if arg == nil { + params[0] = "IS NULL" + } + case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": + param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + switch operator { + case "iexact": + case "contains", "icontains": + param = fmt.Sprintf("%%%s%%", param) + case "startswith", "istartswith": + param = fmt.Sprintf("%s%%", param) + case "endswith", "iendswith": + param = fmt.Sprintf("%%%s", param) + } + params[0] = param + case "isnull": + if b, ok := arg.(bool); ok { + if b { + sql = "IS NULL" + } else { + sql = "IS NOT NULL" + } + params = nil + } else { + panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) + } + } + } + return sql, params +} + +// gernerate sql string with inner function, such as UPPER(text). +func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { + // default not use +} + +// set values to struct column. +func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { + for i, column := range cols { + val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() + + fi := mi.fields.GetByColumn(column) + + field := ind.FieldByIndex(fi.fieldIndex) + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + + _, err = d.setFieldValue(fi, value, field) + + if err != nil { + panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) + } + } +} + +// convert value from database result to value following in field type. +func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { + if val == nil { + return nil, nil + } + + var value interface{} + var tErr error + + var str *StrTo + switch v := val.(type) { + case []byte: + s := StrTo(string(v)) + str = &s + case string: + s := StrTo(v) + str = &s + } + + fieldType := fi.fieldType + +setValue: + switch { + case fieldType == TypeBooleanField: + if str == nil { + switch v := val.(type) { + case int64: + b := v == 1 + value = b + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + b, err := str.Bool() + if err != nil { + tErr = err + goto end + } + value = b + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if str == nil { + value = ToStr(val) + } else { + value = str.String() + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if str == nil { + switch t := val.(type) { + case time.Time: + d.ins.TimeFromDB(&t, tz) + value = t + default: + s := StrTo(ToStr(t)) + str = &s + } + } + if str != nil { + s := str.String() + var ( + t time.Time + err error + ) + if len(s) >= 19 { + s = s[:19] + t, err = time.ParseInLocation(formatDateTime, s, tz) + } else if len(s) >= 10 { + if len(s) > 10 { + s = s[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else if len(s) >= 8 { + if len(s) > 8 { + s = s[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + t = t.In(DefaultTimeLoc) + + if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { + tErr = err + goto end + } + value = t + } + case fieldType&IsIntegerField > 0: + if str == nil { + s := StrTo(ToStr(val)) + str = &s + } + if str != nil { + var err error + switch fieldType { + case TypeBitField: + _, err = str.Int8() + case TypeSmallIntegerField: + _, err = str.Int16() + case TypeIntegerField: + _, err = str.Int32() + case TypeBigIntegerField: + _, err = str.Int64() + case TypePositiveBitField: + _, err = str.Uint8() + case TypePositiveSmallIntegerField: + _, err = str.Uint16() + case TypePositiveIntegerField: + _, err = str.Uint32() + case TypePositiveBigIntegerField: + _, err = str.Uint64() + } + if err != nil { + tErr = err + goto end + } + if fieldType&IsPositiveIntegerField > 0 { + v, _ := str.Uint64() + value = v + } else { + v, _ := str.Int64() + value = v + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if str == nil { + switch v := val.(type) { + case float64: + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + v, err := str.Float64() + if err != nil { + tErr = err + goto end + } + value = v + } + case fieldType&IsRelField > 0: + fi = fi.relModelInfo.fields.pk + fieldType = fi.fieldType + goto setValue + } + +end: + if tErr != nil { + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + return nil, err + } + + return value, nil + +} + +// set one value to struct column field. +func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { + + fieldType := fi.fieldType + isNative := !fi.isFielder + +setValue: + switch { + case fieldType == TypeBooleanField: + if isNative { + if nb, ok := field.Interface().(sql.NullBool); ok { + if value == nil { + nb.Valid = false + } else { + nb.Bool = value.(bool) + nb.Valid = true + } + field.Set(reflect.ValueOf(nb)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(bool) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = false + } + field.SetBool(value.(bool)) + } + } + case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: + if isNative { + if ns, ok := field.Interface().(sql.NullString); ok { + if value == nil { + ns.Valid = false + } else { + ns.String = value.(string) + ns.Valid = true + } + field.Set(reflect.ValueOf(ns)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(string) + field.Set(reflect.ValueOf(&v)) + } + } else { + if value == nil { + value = "" + } + field.SetString(value.(string)) + } + } + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: + if isNative { + if value == nil { + value = time.Time{} + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(time.Time) + field.Set(reflect.ValueOf(&v)) + } + } else { + field.Set(reflect.ValueOf(value)) + } + } + case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint8(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := uint16(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(uint)) { + v := uint(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := uint32(value.(uint64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(uint64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeBitField && field.Kind() == reflect.Ptr: + if value != nil { + v := int8(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := int16(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + if field.Type() == reflect.TypeOf(new(int)) { + v := int(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := int32(value.(int64)) + field.Set(reflect.ValueOf(&v)) + } + } + case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: + if value != nil { + v := value.(int64) + field.Set(reflect.ValueOf(&v)) + } + case fieldType&IsIntegerField > 0: + if fieldType&IsPositiveIntegerField > 0 { + if isNative { + if value == nil { + value = uint64(0) + } + field.SetUint(value.(uint64)) + } + } else { + if isNative { + if ni, ok := field.Interface().(sql.NullInt64); ok { + if value == nil { + ni.Valid = false + } else { + ni.Int64 = value.(int64) + ni.Valid = true + } + field.Set(reflect.ValueOf(ni)) + } else { + if value == nil { + value = int64(0) + } + field.SetInt(value.(int64)) + } + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if isNative { + if nf, ok := field.Interface().(sql.NullFloat64); ok { + if value == nil { + nf.Valid = false + } else { + nf.Float64 = value.(float64) + nf.Valid = true + } + field.Set(reflect.ValueOf(nf)) + } else if field.Kind() == reflect.Ptr { + if value != nil { + if field.Type() == reflect.TypeOf(new(float32)) { + v := float32(value.(float64)) + field.Set(reflect.ValueOf(&v)) + } else { + v := value.(float64) + field.Set(reflect.ValueOf(&v)) + } + } + } else { + + if value == nil { + value = float64(0) + } + field.SetFloat(value.(float64)) + } + } + case fieldType&IsRelField > 0: + if value != nil { + fieldType = fi.relModelInfo.fields.pk.fieldType + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = f + goto setValue + } + } + + if !isNative { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) + return nil, err + } + } + + return value, nil +} + +// query sql, read values , save to *[]ParamList. +func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { + + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch v := container.(type) { + case *[]Params: + d := *v + if len(d) == 0 { + maps = d + } + typ = 1 + case *[]ParamsList: + d := *v + if len(d) == 0 { + lists = d + } + typ = 2 + case *ParamsList: + d := *v + if len(d) == 0 { + list = d + } + typ = 3 + default: + panic(fmt.Errorf("unsupport read values type `%T`", container)) + } + + tables := newDbTables(mi, d.ins) + + var ( + cols []string + infos []*fieldInfo + ) + + hasExprs := len(exprs) > 0 + + Q := d.ins.TableQuote() + + if hasExprs { + cols = make([]string, 0, len(exprs)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, ex := range exprs { + index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", ex)) + } + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) + infos = append(infos, fi) + } + } else { + cols = make([]string, 0, len(mi.fields.dbcols)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, fi := range mi.fields.fieldsDB { + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) + infos = append(infos, fi) + } + } + + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, qs.offset, qs.limit) + join := tables.getJoinSQL() + + sels := strings.Join(cols, ", ") + + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + + d.ins.ReplaceMarks(&query) + + rs, err := q.Query(query, args...) + if err != nil { + return 0, err + } + refs := make([]interface{}, len(cols)) + for i := range refs { + var ref interface{} + refs[i] = &ref + } + + defer rs.Close() + + var ( + cnt int64 + columns []string + ) + for rs.Next() { + if cnt == 0 { + cols, err := rs.Columns() + if err != nil { + return 0, err + } + columns = cols + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params[columns[i]] = value + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + params = append(params, value) + } + lists = append(lists, params) + case 3: + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.convertValueFromDB(fi, val, tz) + if err != nil { + panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) + } + + list = append(list, value) + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { + return 0, nil +} + +// flag of update joined record. +func (d *dbBase) SupportUpdateJoin() bool { + return true +} + +func (d *dbBase) MaxLimit() uint64 { + return 18446744073709551615 +} + +// return quote. +func (d *dbBase) TableQuote() string { + return "`" +} + +// replace value placeholder in parametered sql string. +func (d *dbBase) ReplaceMarks(query *string) { + // default use `?` as mark, do nothing +} + +// flag of RETURNING sql. +func (d *dbBase) HasReturningID(*modelInfo, *string) bool { + return false +} + +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + +// convert time from db. +func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// convert time to db. +func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { + *t = t.In(tz) +} + +// get database types. +func (d *dbBase) DbTypes() map[string]string { + return nil +} + +// gt all tables. +func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { + tables := make(map[string]bool) + query := d.ins.ShowTablesQuery() + rows, err := db.Query(query) + if err != nil { + return tables, err + } + + defer rows.Close() + + for rows.Next() { + var table string + err := rows.Scan(&table) + if err != nil { + return tables, err + } + if table != "" { + tables[table] = true + } + } + + return tables, nil +} + +// get all cloumns in table. +func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + columns := make(map[string][3]string) + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return columns, err + } + + defer rows.Close() + + for rows.Next() { + var ( + name string + typ string + null string + ) + err := rows.Scan(&name, &typ, &null) + if err != nil { + return columns, err + } + columns[name] = [3]string{name, typ, null} + } + + return columns, nil +} + +// not implement. +func (d *dbBase) OperatorSQL(operator string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowTablesQuery() string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) ShowColumnsQuery(table string) string { + panic(ErrNotImplement) +} + +// not implement. +func (d *dbBase) IndexExists(dbQuerier, string, string) bool { + panic(ErrNotImplement) +} diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go new file mode 100644 index 00000000..bf6c350c --- /dev/null +++ b/pkg/orm/db_alias.go @@ -0,0 +1,466 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + lru "github.com/hashicorp/golang-lru" + "reflect" + "sync" + "time" +) + +// DriverType database driver constant int. +type DriverType int + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL // mysql + DRSqlite // sqlite + DROracle // oracle + DRPostgres // pgsql + DRTiDB // TiDB +) + +// database driver string. +type driver string + +// get type constant int of current driver.. +func (d driver) Type() DriverType { + a, _ := dataBaseCache.get(string(d)) + return a.Driver +} + +// get name of current driver +func (d driver) Name() string { + return string(d) +} + +// check driver iis implemented Driver interface or not. +var _ Driver = new(driver) + +var ( + dataBaseCache = &_dbCache{cache: make(map[string]*alias)} + drivers = map[string]DriverType{ + "mysql": DRMySQL, + "postgres": DRPostgres, + "sqlite3": DRSqlite, + "tidb": DRTiDB, + "oracle": DROracle, + "oci8": DROracle, // github.com/mattn/go-oci8 + "ora": DROracle, //https://github.com/rana/ora + } + dbBasers = map[DriverType]dbBaser{ + DRMySQL: newdbBaseMysql(), + DRSqlite: newdbBaseSqlite(), + DROracle: newdbBaseOracle(), + DRPostgres: newdbBasePostgres(), + DRTiDB: newdbBaseTidb(), + } +) + +// database alias cacher. +type _dbCache struct { + mux sync.RWMutex + cache map[string]*alias +} + +// add database alias with original name. +func (ac *_dbCache) add(name string, al *alias) (added bool) { + ac.mux.Lock() + defer ac.mux.Unlock() + if _, ok := ac.cache[name]; !ok { + ac.cache[name] = al + added = true + } + return +} + +// get database alias if cached. +func (ac *_dbCache) get(name string) (al *alias, ok bool) { + ac.mux.RLock() + defer ac.mux.RUnlock() + al, ok = ac.cache[name] + return +} + +// get default alias. +func (ac *_dbCache) getDefault() (al *alias) { + al, _ = ac.get("default") + return +} + +type DB struct { + *sync.RWMutex + DB *sql.DB + stmtDecorators *lru.Cache +} + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +//su must call release to release *sql.Stmt after using +func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { + d.RLock() + c, ok := d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() + return c.(*stmtDecorator), nil + } + d.RUnlock() + + d.Lock() + c, ok = d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.Unlock() + return c.(*stmtDecorator), nil + } + + stmt, err := d.Prepare(query) + if err != nil { + d.Unlock() + return nil, err + } + sd := newStmtDecorator(stmt) + sd.acquire() + d.stmtDecorators.Add(query, sd) + d.Unlock() + + return sd, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Exec(args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.Query(args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRow(args...) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRowContext(ctx, args) +} + +type alias struct { + Name string + Driver DriverType + DriverName string + DataSource string + MaxIdleConns int + MaxOpenConns int + DB *DB + DbBaser dbBaser + TZ *time.Location + Engine string +} + +func detectTZ(al *alias) { + // orm timezone system match database + // default use Local + al.TZ = DefaultTimeLoc + + if al.DriverName == "sphinx" { + return + } + + switch al.Driver { + case DRMySQL: + row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") + var tz string + row.Scan(&tz) + if len(tz) >= 8 { + if tz[0] != '-' { + tz = "+" + tz + } + t, err := time.Parse("-07:00:00", tz) + if err == nil { + if t.Location().String() != "" { + al.TZ = t.Location() + } + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } + + // get default engine from current database + row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") + var engine string + var tx bool + row.Scan(&engine, &tx) + + if engine != "" { + al.Engine = engine + } else { + al.Engine = "INNODB" + } + + case DRSqlite, DROracle: + al.TZ = time.UTC + + case DRPostgres: + row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") + var tz string + row.Scan(&tz) + loc, err := time.LoadLocation(tz) + if err == nil { + al.TZ = loc + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } +} + +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + _, err := addAliasWthDB(aliasName, driverName, db) + return err +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + + for i, v := range params { + switch i { + case 0: + SetMaxIdleConns(al.Name, v) + case 1: + SetMaxOpenConns(al.Name, v) + } + } + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + if t, ok := drivers[driverName]; !ok { + drivers[driverName] = typ + } else { + if t != typ { + return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) + } + } + return nil +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + if al, ok := dataBaseCache.get(aliasName); ok { + al.TZ = tz + } else { + return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) + } + return nil +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) + // for tip go 1.2 + if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { + fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) + } +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + var name string + if len(aliasNames) > 0 { + name = aliasNames[0] + } else { + name = "default" + } + al, ok := dataBaseCache.get(name) + if ok { + return al.DB.DB, nil + } + return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) +} + +type stmtDecorator struct { + wg sync.WaitGroup + stmt *sql.Stmt +} + +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this +func (s *stmtDecorator) acquire() { + s.wg.Add(1) +} + +func (s *stmtDecorator) release() { + s.wg.Done() +} + +//garbage recycle for stmt +func (s *stmtDecorator) destroy() { + go func() { + s.wg.Wait() + _ = s.stmt.Close() + }() +} + +func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { + return &stmtDecorator{ + stmt: sqlStmt, + } +} + +func newStmtDecoratorLruWithEvict() *lru.Cache { + cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { + value.(*stmtDecorator).destroy() + }) + return cache +} diff --git a/pkg/orm/db_mysql.go b/pkg/orm/db_mysql.go new file mode 100644 index 00000000..6e99058e --- /dev/null +++ b/pkg/orm/db_mysql.go @@ -0,0 +1,183 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strings" +) + +// mysql operators. +var mysqlOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", + // "regex": "REGEXP BINARY ?", + // "iregex": "REGEXP ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE BINARY ?", + "endswith": "LIKE BINARY ?", + "istartswith": "LIKE ?", + "iendswith": "LIKE ?", +} + +// mysql column field types. +var mysqlTypes = map[string]string{ + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", +} + +// mysql dbBaser implementation. +type dbBaseMysql struct { + dbBase +} + +var _ dbBaser = new(dbBaseMysql) + +// get mysql operator. +func (d *dbBaseMysql) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseMysql) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseMysql) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseMysql) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +// Add "`" for mysql sql building +func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + var iouStr string + argsMap := map[string]string{} + + iouStr = "ON DUPLICATE KEY UPDATE" + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if valueStr != "" { + updates[i] = "`" + v + "`" + "=" + valueStr + } else { + updates[i] = "`" + v + "`" + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + return id, err +} + +// create new mysql dbBaser. +func newdbBaseMysql() dbBaser { + b := new(dbBaseMysql) + b.ins = b + return b +} diff --git a/pkg/orm/db_oracle.go b/pkg/orm/db_oracle.go new file mode 100644 index 00000000..5d121f83 --- /dev/null +++ b/pkg/orm/db_oracle.go @@ -0,0 +1,137 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// oracle operators. +var oracleOperators = map[string]string{ + "exact": "= ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "//iendswith": "LIKE ?", +} + +// oracle column field types. +var oracleTypes = map[string]string{ + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", +} + +// oracle dbBaser +type dbBaseOracle struct { + dbBase +} + +var _ dbBaser = new(dbBaseOracle) + +// create oracle dbBaser. +func newdbBaseOracle() dbBaser { + b := new(dbBaseOracle) + b.ins = b + return b +} + +// OperatorSQL get oracle operator. +func (d *dbBaseOracle) OperatorSQL(operator string) string { + return oracleOperators[operator] +} + +// DbTypes get oracle table field types. +func (d *dbBaseOracle) DbTypes() map[string]string { + return oracleTypes +} + +//ShowTablesQuery show all the tables in database +func (d *dbBaseOracle) ShowTablesQuery() string { + return "SELECT TABLE_NAME FROM USER_TABLES" +} + +// Oracle +func (d *dbBaseOracle) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ + "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) +} + +// check index is exist +func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ + "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ + "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) + + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. +func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + Q := d.ins.TableQuote() + + marks := make([]string, len(names)) + for i := range marks { + marks[i] = ":" + names[i] + } + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} diff --git a/pkg/orm/db_postgres.go b/pkg/orm/db_postgres.go new file mode 100644 index 00000000..c488fb38 --- /dev/null +++ b/pkg/orm/db_postgres.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" +) + +// postgresql operators. +var postgresOperators = map[string]string{ + "exact": "= ?", + "iexact": "= UPPER(?)", + "contains": "LIKE ?", + "icontains": "LIKE UPPER(?)", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ?", + "endswith": "LIKE ?", + "istartswith": "LIKE UPPER(?)", + "iendswith": "LIKE UPPER(?)", +} + +// postgresql column field types. +var postgresTypes = map[string]string{ + "auto": "serial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", +} + +// postgresql dbBaser. +type dbBasePostgres struct { + dbBase +} + +var _ dbBaser = new(dbBasePostgres) + +// get postgresql operator. +func (d *dbBasePostgres) OperatorSQL(operator string) string { + return postgresOperators[operator] +} + +// generate functioned sql string, such as contains(text). +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + switch operator { + case "contains", "startswith", "endswith": + *leftCol = fmt.Sprintf("%s::text", *leftCol) + case "iexact", "icontains", "istartswith", "iendswith": + *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) + } +} + +// postgresql unsupports updating joined record. +func (d *dbBasePostgres) SupportUpdateJoin() bool { + return false +} + +func (d *dbBasePostgres) MaxLimit() uint64 { + return 0 +} + +// postgresql quote is ". +func (d *dbBasePostgres) TableQuote() string { + return `"` +} + +// postgresql value placeholder is $n. +// replace default ? to $n. +func (d *dbBasePostgres) ReplaceMarks(query *string) { + q := *query + num := 0 + for _, c := range q { + if c == '?' { + num++ + } + } + if num == 0 { + return + } + data := make([]byte, 0, len(q)+num) + num = 1 + for i := 0; i < len(q); i++ { + c := q[i] + if c == '?' { + data = append(data, '$') + data = append(data, []byte(strconv.Itoa(num))...) + num++ + } else { + data = append(data, c) + } + } + *query = string(data) +} + +// make returning sql support for postgresql. +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { + fi := mi.fields.pk + if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { + return false + } + + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + } + return true +} + +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil +} + +// show table sql for postgresql. +func (d *dbBasePostgres) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" +} + +// show table columns sql for postgresql. +func (d *dbBasePostgres) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) +} + +// get column types of postgresql. +func (d *dbBasePostgres) DbTypes() map[string]string { + return postgresTypes +} + +// check index exist in postgresql. +func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) + row := db.QueryRow(query) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new postgresql dbBaser. +func newdbBasePostgres() dbBaser { + b := new(dbBasePostgres) + b.ins = b + return b +} diff --git a/pkg/orm/db_sqlite.go b/pkg/orm/db_sqlite.go new file mode 100644 index 00000000..1d62ee34 --- /dev/null +++ b/pkg/orm/db_sqlite.go @@ -0,0 +1,161 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// sqlite operators. +var sqliteOperators = map[string]string{ + "exact": "= ?", + "iexact": "LIKE ? ESCAPE '\\'", + "contains": "LIKE ? ESCAPE '\\'", + "icontains": "LIKE ? ESCAPE '\\'", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "eq": "= ?", + "ne": "!= ?", + "startswith": "LIKE ? ESCAPE '\\'", + "endswith": "LIKE ? ESCAPE '\\'", + "istartswith": "LIKE ? ESCAPE '\\'", + "iendswith": "LIKE ? ESCAPE '\\'", +} + +// sqlite column types. +var sqliteTypes = map[string]string{ + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", +} + +// sqlite dbBaser. +type dbBaseSqlite struct { + dbBase +} + +var _ dbBaser = new(dbBaseSqlite) + +// override base db read for update behavior as SQlite does not support syntax +func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + if isForUpdate { + DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") + } + return d.dbBase.Read(q, mi, ind, tz, cols, false) +} + +// get sqlite operator. +func (d *dbBaseSqlite) OperatorSQL(operator string) string { + return sqliteOperators[operator] +} + +// generate functioned sql for sqlite. +// only support DATE(text). +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { + if fi.fieldType == TypeDateField { + *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) + } +} + +// unable updating joined record in sqlite. +func (d *dbBaseSqlite) SupportUpdateJoin() bool { + return false +} + +// max int in sqlite. +func (d *dbBaseSqlite) MaxLimit() uint64 { + return 9223372036854775807 +} + +// get column types in sqlite. +func (d *dbBaseSqlite) DbTypes() map[string]string { + return sqliteTypes +} + +// get show tables sql in sqlite. +func (d *dbBaseSqlite) ShowTablesQuery() string { + return "SELECT name FROM sqlite_master WHERE type = 'table'" +} + +// get columns in sqlite. +func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { + query := d.ins.ShowColumnsQuery(table) + rows, err := db.Query(query) + if err != nil { + return nil, err + } + + columns := make(map[string][3]string) + for rows.Next() { + var tmp, name, typ, null sql.NullString + err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) + if err != nil { + return nil, err + } + columns[name.String] = [3]string{name.String, typ.String, null.String} + } + + return columns, nil +} + +// get show columns sql in sqlite. +func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { + return fmt.Sprintf("pragma table_info('%s')", table) +} + +// check index exist in sqlite. +func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { + query := fmt.Sprintf("PRAGMA index_list('%s')", table) + rows, err := db.Query(query) + if err != nil { + panic(err) + } + defer rows.Close() + for rows.Next() { + var tmp, index sql.NullString + rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) + if name == index.String { + return true + } + } + return false +} + +// create new sqlite dbBaser. +func newdbBaseSqlite() dbBaser { + b := new(dbBaseSqlite) + b.ins = b + return b +} diff --git a/pkg/orm/db_tables.go b/pkg/orm/db_tables.go new file mode 100644 index 00000000..4b21a6fc --- /dev/null +++ b/pkg/orm/db_tables.go @@ -0,0 +1,482 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" + "time" +) + +// table info struct. +type dbTable struct { + id int + index string + name string + names []string + sel bool + inner bool + mi *modelInfo + fi *fieldInfo + jtl *dbTable +} + +// tables collection struct, contains some tables. +type dbTables struct { + tablesM map[string]*dbTable + tables []*dbTable + mi *modelInfo + base dbBaser + skipEnd bool +} + +// set table info to collection. +// if not exist, create new. +func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { + name := strings.Join(names, ExprSep) + if j, ok := t.tablesM[name]; ok { + j.name = name + j.mi = mi + j.fi = fi + j.inner = inner + } else { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + } + return t.tablesM[name] +} + +// add table info to collection. +func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { + name := strings.Join(names, ExprSep) + if _, ok := t.tablesM[name]; !ok { + i := len(t.tables) + 1 + jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} + t.tablesM[name] = jt + t.tables = append(t.tables, jt) + return jt, true + } + return t.tablesM[name], false +} + +// get table info in collection. +func (t *dbTables) get(name string) (*dbTable, bool) { + j, ok := t.tablesM[name] + return j, ok +} + +// get related fields info in recursive depth loop. +// loop once, depth decreases one. +func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { + if depth < 0 || fi.fieldType == RelManyToMany { + return related + } + + if prefix == "" { + prefix = fi.name + } else { + prefix = prefix + ExprSep + fi.name + } + related = append(related, prefix) + + depth-- + for _, fi := range fi.relModelInfo.fields.fieldsRel { + related = t.loopDepth(depth, prefix, fi, related) + } + + return related +} + +// parse related fields. +func (t *dbTables) parseRelated(rels []string, depth int) { + + relsNum := len(rels) + related := make([]string, relsNum) + copy(related, rels) + + relDepth := depth + + if relsNum != 0 { + relDepth = 0 + } + + relDepth-- + for _, fi := range t.mi.fields.fieldsRel { + related = t.loopDepth(relDepth, "", fi, related) + } + + for i, s := range related { + var ( + exs = strings.Split(s, ExprSep) + names = make([]string, 0, len(exs)) + mmi = t.mi + cancel = true + jtl *dbTable + ) + + inner := true + + for _, ex := range exs { + if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { + names = append(names, fi.name) + mmi = fi.relModelInfo + + if fi.null || t.skipEnd { + inner = false + } + + jt := t.set(names, mmi, fi, inner) + jt.jtl = jtl + + if fi.reverse { + cancel = false + } + + if cancel { + jt.sel = depth > 0 + + if i < relsNum { + jt.sel = true + } + } + + jtl = jt + + } else { + panic(fmt.Errorf("unknown model/table name `%s`", ex)) + } + } + } +} + +// generate join string. +func (t *dbTables) getJoinSQL() (join string) { + Q := t.base.TableQuote() + + for _, jt := range t.tables { + if jt.inner { + join += "INNER JOIN " + } else { + join += "LEFT OUTER JOIN " + } + var ( + table string + t1, t2 string + c1, c2 string + ) + t1 = "T0" + if jt.jtl != nil { + t1 = jt.jtl.index + } + t2 = jt.index + table = jt.mi.table + + switch { + case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + c1 = jt.fi.mi.fields.pk.column + for _, ffi := range jt.mi.fields.fieldsRel { + if jt.fi.mi == ffi.relModelInfo { + c2 = ffi.column + break + } + } + default: + c1 = jt.fi.column + c2 = jt.fi.relModelInfo.fields.pk.column + + if jt.fi.reverse { + c1 = jt.mi.fields.pk.column + c2 = jt.fi.reverseFieldInfo.column + } + } + + join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, + t2, Q, c2, Q, t1, Q, c1, Q) + } + return +} + +// parse orm model struct field tag expression. +func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { + var ( + jtl *dbTable + fi *fieldInfo + fiN *fieldInfo + mmi = mi + ) + + num := len(exprs) - 1 + var names []string + + inner := true + +loopFor: + for i, ex := range exprs { + + var ok, okN bool + + if fiN != nil { + fi = fiN + ok = true + fiN = nil + } + + if i == 0 { + fi, ok = mmi.fields.GetByAny(ex) + } + + _ = okN + + if ok { + + isRel := fi.rel || fi.reverse + + names = append(names, fi.name) + + switch { + case fi.rel: + mmi = fi.relModelInfo + if fi.fieldType == RelManyToMany { + mmi = fi.relThroughModelInfo + } + case fi.reverse: + mmi = fi.reverseFieldInfo.mi + } + + if i < num { + fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + } + + if isRel && (!fi.mi.isThrough || num != i) { + if fi.null || t.skipEnd { + inner = false + } + + if t.skipEnd && okN || !t.skipEnd { + if t.skipEnd && okN && fiN.pk { + goto loopEnd + } + + jt, _ := t.add(names, mmi, fi, inner) + jt.jtl = jtl + jtl = jt + } + + } + + if num != i { + continue + } + + loopEnd: + + if i == 0 || jtl == nil { + index = "T0" + } else { + index = jtl.index + } + + info = fi + + if jtl == nil { + name = fi.name + } else { + name = jtl.name + ExprSep + fi.name + } + + switch { + case fi.rel: + + case fi.reverse: + switch fi.reverseFieldInfo.fieldType { + case RelOneToOne, RelForeignKey: + index = jtl.index + info = fi.reverseFieldInfo.mi.fields.pk + name = info.name + } + } + + break loopFor + + } else { + index = "" + name = "" + info = nil + success = false + return + } + } + + success = index != "" && info != nil + return +} + +// generate condition sql. +func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { + if cond == nil || cond.IsEmpty() { + return + } + + Q := t.base.TableQuote() + + mi := t.mi + + for i, p := range cond.params { + if i > 0 { + if p.isOr { + where += "OR " + } else { + where += "AND " + } + } + if p.isNot { + where += "NOT " + } + if p.isCond { + w, ps := t.getCondSQL(p.cond, true, tz) + if w != "" { + w = fmt.Sprintf("( %s) ", w) + } + where += w + params = append(params, ps...) + } else { + exprs := p.exprs + + num := len(exprs) - 1 + operator := "" + if operators[exprs[num]] { + operator = exprs[num] + exprs = exprs[:num] + } + + index, _, fi, suc := t.parseExprs(mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) + } + + if operator == "" { + operator = "exact" + } + + var operSQL string + var args []interface{} + if p.isRaw { + operSQL = p.sql + } else { + operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) + } + + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) + + where += fmt.Sprintf("%s %s ", leftCol, operSQL) + params = append(params, args...) + + } + } + + if !sub && where != "" { + where = "WHERE " + where + } + + return +} + +// generate group sql. +func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { + if len(groups) == 0 { + return + } + + Q := t.base.TableQuote() + + groupSqls := make([]string, 0, len(groups)) + for _, group := range groups { + exprs := strings.Split(group, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + } + + groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) + return +} + +// generate order sql. +func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { + if len(orders) == 0 { + return + } + + Q := t.base.TableQuote() + + orderSqls := make([]string, 0, len(orders)) + for _, order := range orders { + asc := "ASC" + if order[0] == '-' { + asc = "DESC" + order = order[1:] + } + exprs := strings.Split(order, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + } + + orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + return +} + +// generate limit sql. +func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { + if limit == 0 { + limit = int64(DefaultRowsLimit) + } + if limit < 0 { + // no limit + if offset > 0 { + maxLimit := t.base.MaxLimit() + if maxLimit == 0 { + limits = fmt.Sprintf("OFFSET %d", offset) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) + } + } + } else if offset <= 0 { + limits = fmt.Sprintf("LIMIT %d", limit) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) + } + return +} + +// crete new tables collection. +func newDbTables(mi *modelInfo, base dbBaser) *dbTables { + tables := &dbTables{} + tables.tablesM = make(map[string]*dbTable) + tables.mi = mi + tables.base = base + return tables +} diff --git a/pkg/orm/db_tidb.go b/pkg/orm/db_tidb.go new file mode 100644 index 00000000..6020a488 --- /dev/null +++ b/pkg/orm/db_tidb.go @@ -0,0 +1,63 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" +) + +// mysql dbBaser implementation. +type dbBaseTidb struct { + dbBase +} + +var _ dbBaser = new(dbBaseTidb) + +// get mysql operator. +func (d *dbBaseTidb) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseTidb) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseTidb) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseTidb) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new mysql dbBaser. +func newdbBaseTidb() dbBaser { + b := new(dbBaseTidb) + b.ins = b + return b +} diff --git a/pkg/orm/db_utils.go b/pkg/orm/db_utils.go new file mode 100644 index 00000000..7ae10ca5 --- /dev/null +++ b/pkg/orm/db_utils.go @@ -0,0 +1,177 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "time" +) + +// get table alias. +func getDbAlias(name string) *alias { + if al, ok := dataBaseCache.get(name); ok { + return al + } + panic(fmt.Errorf("unknown DataBase alias name %s", name)) +} + +// get pk column info. +func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.fields.pk + + v := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsPositiveIntegerField > 0 { + vu := v.Uint() + exist = vu > 0 + value = vu + } else if fi.fieldType&IsIntegerField > 0 { + vu := v.Int() + exist = true + value = vu + } else if fi.fieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) + } else { + vu := v.String() + exist = vu != "" + value = vu + } + + column = fi.column + return +} + +// get fields description as flatted string. +func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { + +outFor: + for _, arg := range args { + val := reflect.ValueOf(arg) + + if arg == nil { + params = append(params, arg) + continue + } + + kind := val.Kind() + if kind == reflect.Ptr { + val = val.Elem() + kind = val.Kind() + arg = val.Interface() + } + + switch kind { + case reflect.String: + v := val.String() + if fi != nil { + if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + var t time.Time + var err error + if len(v) >= 19 { + s := v[:19] + t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) + } else if len(v) >= 10 { + s := v + if len(v) > 10 { + s = v[:10] + } + t, err = time.ParseInLocation(formatDate, s, tz) + } else { + s := v + if len(s) > 8 { + s = v[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) + } + if err == nil { + if fi.fieldType == TypeDateField { + v = t.In(tz).Format(formatDate) + } else if fi.fieldType == TypeDateTimeField { + v = t.In(tz).Format(formatDateTime) + } else { + v = t.In(tz).Format(formatTime) + } + } + } + } + arg = v + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg = val.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + arg = val.Uint() + case reflect.Float32: + arg, _ = StrTo(ToStr(arg)).Float64() + case reflect.Float64: + arg = val.Float() + case reflect.Bool: + arg = val.Bool() + case reflect.Slice, reflect.Array: + if _, ok := arg.([]byte); ok { + continue outFor + } + + var args []interface{} + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + + var vu interface{} + if v.CanInterface() { + vu = v.Interface() + } + + if vu == nil { + continue + } + + args = append(args, vu) + } + + if len(args) > 0 { + p := getFlatParams(fi, args, tz) + params = append(params, p...) + } + continue outFor + case reflect.Struct: + if v, ok := arg.(time.Time); ok { + if fi != nil && fi.fieldType == TypeDateField { + arg = v.In(tz).Format(formatDate) + } else if fi != nil && fi.fieldType == TypeDateTimeField { + arg = v.In(tz).Format(formatDateTime) + } else if fi != nil && fi.fieldType == TypeTimeField { + arg = v.In(tz).Format(formatTime) + } else { + arg = v.In(tz).Format(formatDateTime) + } + } else { + typ := val.Type() + name := getFullName(typ) + var value interface{} + if mmi, ok := modelCache.getByFullName(name); ok { + if _, vu, exist := getExistPk(mmi, val); exist { + value = vu + } + } + arg = value + + if arg == nil { + panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) + } + } + } + + params = append(params, arg) + } + return +} diff --git a/pkg/orm/models.go b/pkg/orm/models.go new file mode 100644 index 00000000..4776bcba --- /dev/null +++ b/pkg/orm/models.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "sync" +) + +const ( + odCascade = "cascade" + odSetNULL = "set_null" + odSetDefault = "set_default" + odDoNothing = "do_nothing" + defaultStructTagName = "orm" + defaultStructTagDelim = ";" +) + +var ( + modelCache = &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +) + +// model info collection +type _modelCache struct { + sync.RWMutex // only used outsite for bootStrap + orders []string + cache map[string]*modelInfo + cacheByFullName map[string]*modelInfo + done bool +} + +// get all model info +func (mc *_modelCache) all() map[string]*modelInfo { + m := make(map[string]*modelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +// get ordered model info +func (mc *_modelCache) allOrdered() []*modelInfo { + m := make([]*modelInfo, 0, len(mc.orders)) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) + } + return m +} + +// get model info by table name +func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +// get model info by full name +func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] + return +} + +// set model info to collection +func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + mc.cacheByFullName[mi.fullName] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} + +// clean all model info. +func (mc *_modelCache) clean() { + mc.orders = make([]string, 0) + mc.cache = make(map[string]*modelInfo) + mc.cacheByFullName = make(map[string]*modelInfo) + mc.done = false +} + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + modelCache.clean() +} diff --git a/pkg/orm/models_boot.go b/pkg/orm/models_boot.go new file mode 100644 index 00000000..8c56b3c4 --- /dev/null +++ b/pkg/orm/models_boot.go @@ -0,0 +1,347 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" + "runtime/debug" + "strings" +) + +// register models. +// PrefixOrSuffix means table name prefix or suffix. +// isPrefix whether the prefix is prefix or suffix +func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) + } + + table := getTableName(val) + + if PrefixOrSuffix != "" { + if isPrefix { + table = PrefixOrSuffix + table + } else { + table = table + PrefixOrSuffix + } + } + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := modelCache.getByFullName(name); ok { + fmt.Printf(" model `%s` repeat register, must be unique\n", name) + os.Exit(2) + } + + if _, ok := modelCache.get(table); ok { + fmt.Printf(" table name `%s` repeat register, must be unique\n", table) + os.Exit(2) + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + os.Exit(2) + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + modelCache.set(table, mi) +} + +// bootstrap models +func bootStrap() { + if modelCache.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := modelCache.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := modelCache.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := modelCache.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + os.Exit(2) + } +} + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModel must be run before BootStrap")) + } + RegisterModelWithPrefix("", models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(prefix, model, true) + } +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) + } + + for _, model := range models { + registerModel(suffix, model, false) + } +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + modelCache.Lock() + defer modelCache.Unlock() + if modelCache.done { + return + } + bootStrap() + modelCache.done = true +} diff --git a/pkg/orm/models_fields.go b/pkg/orm/models_fields.go new file mode 100644 index 00000000..b4fad94f --- /dev/null +++ b/pkg/orm/models_fields.go @@ -0,0 +1,783 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "time" +) + +// Define the Type enum +const ( + TypeBooleanField = 1 << iota + TypeVarCharField + TypeCharField + TypeTextField + TypeTimeField + TypeDateField + TypeDateTimeField + TypeBitField + TypeSmallIntegerField + TypeIntegerField + TypeBigIntegerField + TypePositiveBitField + TypePositiveSmallIntegerField + TypePositiveIntegerField + TypePositiveBigIntegerField + TypeFloatField + TypeDecimalField + TypeJSONField + TypeJsonbField + RelForeignKey + RelOneToOne + RelManyToMany + RelReverseOne + RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 + IsRelField = ^-RelReverseMany >> 18 << 19 + IsFieldType = ^-RelReverseMany<<1 + 1 +) + +// BooleanField A true/false field. +type BooleanField bool + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return bool(e) +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + *e = BooleanField(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return strconv.FormatBool(e.Value()) +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return TypeBooleanField +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + switch d := value.(type) { + case bool: + e.Set(d) + case string: + v, err := StrTo(d).Bool() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return e.Value() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField string + +// Value return the CharField's Value +func (e CharField) Value() string { + return string(e) +} + +// Set CharField value +func (e *CharField) Set(d string) { + *e = CharField(d) +} + +// String return the CharField +func (e *CharField) String() string { + return e.Value() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return TypeVarCharField +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return e.Value() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField time.Time + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + *e = TimeField(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return e.Value() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField time.Time + +// Value return the time.Time +func (e DateField) Value() time.Time { + return time.Time(e) +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + *e = DateField(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDate) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return e.Value() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField time.Time + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + *e = DateTimeField(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return e.Value().String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return TypeDateTimeField +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatDateTime) + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return e.Value() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField float64 + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return float64(e) +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + *e = FloatField(d) +} + +// String return the string +func (e *FloatField) String() string { + return ToStr(e.Value(), -1, 32) +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return TypeFloatField +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + switch d := value.(type) { + case float32: + e.Set(float64(d)) + case float64: + e.Set(d) + case string: + v, err := StrTo(d).Float64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return e.Value() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField int16 + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return int16(e) +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + *e = SmallIntegerField(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return TypeSmallIntegerField +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int16: + e.Set(d) + case string: + v, err := StrTo(d).Int16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField int32 + +// Value return the int32 +func (e IntegerField) Value() int32 { + return int32(e) +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + *e = IntegerField(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return TypeIntegerField +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int32: + e.Set(d) + case string: + v, err := StrTo(d).Int32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return e.Value() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField int64 + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return int64(e) +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + *e = BigIntegerField(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return TypeBigIntegerField +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int64: + e.Set(d) + case string: + v, err := StrTo(d).Int64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField uint16 + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return uint16(e) +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + *e = PositiveSmallIntegerField(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return TypePositiveSmallIntegerField +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint16: + e.Set(d) + case string: + v, err := StrTo(d).Uint16() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField uint32 + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return uint32(e) +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + *e = PositiveIntegerField(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint32: + e.Set(d) + case string: + v, err := StrTo(d).Uint32() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField uint64 + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return uint64(e) +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + *e = PositiveBigIntegerField(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return ToStr(e.Value()) +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint64: + e.Set(d) + case string: + v, err := StrTo(d).Uint64() + if err == nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return e.Value() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField string + +// Value return TextField value +func (e TextField) Value() string { + return string(e) +} + +// Set the TextField value +func (e *TextField) Set(d string) { + *e = TextField(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return e.Value() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return TypeTextField +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return e.Value() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField string + +// Value return JSONField value +func (j JSONField) Value() string { + return string(j) +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + *j = JSONField(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return TypeJSONField +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return j.Value() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField string + +// Value return JsonbField value +func (j JsonbField) Value() string { + return string(j) +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + *j = JsonbField(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return TypeJsonbField +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return j.Value() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/orm/models_info_f.go b/pkg/orm/models_info_f.go new file mode 100644 index 00000000..7044b0bd --- /dev/null +++ b/pkg/orm/models_info_f.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var errSkipField = errors.New("skip field") + +// field info collection +type fields struct { + pk *fieldInfo + columns map[string]*fieldInfo + fields map[string]*fieldInfo + fieldsLow map[string]*fieldInfo + fieldsByType map[int][]*fieldInfo + fieldsRel []*fieldInfo + fieldsReverse []*fieldInfo + fieldsDB []*fieldInfo + rels []*fieldInfo + orders []string + dbcols []string +} + +// add field info +func (f *fields) Add(fi *fieldInfo) (added bool) { + if f.fields[fi.name] == nil && f.columns[fi.column] == nil { + f.columns[fi.column] = fi + f.fields[fi.name] = fi + f.fieldsLow[strings.ToLower(fi.name)] = fi + } else { + return + } + if _, ok := f.fieldsByType[fi.fieldType]; !ok { + f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + } + f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) + f.orders = append(f.orders, fi.column) + if fi.dbcol { + f.dbcols = append(f.dbcols, fi.column) + f.fieldsDB = append(f.fieldsDB, fi) + } + if fi.rel { + f.fieldsRel = append(f.fieldsRel, fi) + } + if fi.reverse { + f.fieldsReverse = append(f.fieldsReverse, fi) + } + return true +} + +// get field info by name +func (f *fields) GetByName(name string) *fieldInfo { + return f.fields[name] +} + +// get field info by column name +func (f *fields) GetByColumn(column string) *fieldInfo { + return f.columns[column] +} + +// get field info by string, name is prior +func (f *fields) GetByAny(name string) (*fieldInfo, bool) { + if fi, ok := f.fields[name]; ok { + return fi, ok + } + if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + return fi, ok + } + if fi, ok := f.columns[name]; ok { + return fi, ok + } + return nil, false +} + +// create new field info collection +func newFields() *fields { + f := new(fields) + f.fields = make(map[string]*fieldInfo) + f.fieldsLow = make(map[string]*fieldInfo) + f.columns = make(map[string]*fieldInfo) + f.fieldsByType = make(map[int][]*fieldInfo) + return f +} + +// single field info +type fieldInfo struct { + mi *modelInfo + fieldIndex []int + fieldType int + dbcol bool // table column fk and onetoone + inModel bool + name string + fullName string + column string + addrValue reflect.Value + sf reflect.StructField + auto bool + pk bool + null bool + index bool + unique bool + colDefault bool // whether has default tag + initial StrTo // store the default value + size int + toText bool + autoNow bool + autoNowAdd bool + rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true + reverse bool + reverseField string + reverseFieldInfo *fieldInfo + reverseFieldInfoTwo *fieldInfo + reverseFieldInfoM2M *fieldInfo + relTable string + relThrough string + relThroughModelInfo *modelInfo + relModelInfo *modelInfo + digits int + decimals int + isFielder bool // implement Fielder interface + onDelete string + description string +} + +// new field info +func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { + var ( + tag string + tagValue string + initial StrTo // store the default value + fieldType int + attrs map[string]bool + tags map[string]string + addrField reflect.Value + ) + + fi = new(fieldInfo) + + // if field which CanAddr is the follow type + // A value is addressable if it is an element of a slice, + // an element of an addressable array, a field of an + // addressable struct, or the result of dereferencing a pointer. + addrField = field + if field.CanAddr() && field.Kind() != reflect.Ptr { + addrField = field.Addr() + if _, ok := addrField.Interface().(Fielder); !ok { + if field.Kind() == reflect.Slice { + addrField = field + } + } + } + + attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) + + if _, ok := attrs["-"]; ok { + return nil, errSkipField + } + + digits := tags["digits"] + decimals := tags["decimals"] + size := tags["size"] + onDelete := tags["on_delete"] + + initial.Clear() + if v, ok := tags["default"]; ok { + initial.Set(v) + } + +checkType: + switch f := addrField.Interface().(type) { + case Fielder: + fi.isFielder = true + if field.Kind() == reflect.Ptr { + err = fmt.Errorf("the model Fielder can not be use ptr") + goto end + } + fieldType = f.FieldType() + if fieldType&IsRelField > 0 { + err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") + goto end + } + default: + tag = "rel" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "fk": + fieldType = RelForeignKey + break checkType + case "one": + fieldType = RelOneToOne + break checkType + case "m2m": + fieldType = RelManyToMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("rel only allow these value: fk, one, m2m") + goto wrongTag + } + } + tag = "reverse" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "one": + fieldType = RelReverseOne + break checkType + case "many": + fieldType = RelReverseMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("reverse only allow these value: one, many") + goto wrongTag + } + } + + fieldType, err = getFieldType(addrField) + if err != nil { + goto end + } + if fieldType == TypeVarCharField { + switch tags["type"] { + case "char": + fieldType = TypeCharField + case "text": + fieldType = TypeTextField + case "json": + fieldType = TypeJSONField + case "jsonb": + fieldType = TypeJsonbField + } + } + if fieldType == TypeFloatField && (digits != "" || decimals != "") { + fieldType = TypeDecimalField + } + if fieldType == TypeDateTimeField && tags["type"] == "date" { + fieldType = TypeDateField + } + if fieldType == TypeTimeField && tags["type"] == "time" { + fieldType = TypeTimeField + } + } + + // check the rel and reverse type + // rel should Ptr + // reverse should slice []*struct + switch fieldType { + case RelForeignKey, RelOneToOne, RelReverseOne: + if field.Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) + goto end + } + case RelManyToMany, RelReverseMany: + if field.Kind() != reflect.Slice { + err = fmt.Errorf("rel/reverse:many field must be slice") + goto end + } else { + if field.Type().Elem().Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) + goto end + } + } + } + + if fieldType&IsFieldType == 0 { + err = fmt.Errorf("wrong field type") + goto end + } + + fi.fieldType = fieldType + fi.name = sf.Name + fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.addrValue = addrField + fi.sf = sf + fi.fullName = mi.fullName + mName + "." + sf.Name + + fi.description = tags["description"] + fi.null = attrs["null"] + fi.index = attrs["index"] + fi.auto = attrs["auto"] + fi.pk = attrs["pk"] + fi.unique = attrs["unique"] + + // Mark object property if there is attribute "default" in the orm configuration + if _, ok := tags["default"]; ok { + fi.colDefault = true + } + + switch fieldType { + case RelManyToMany, RelReverseMany, RelReverseOne: + fi.null = false + fi.index = false + fi.auto = false + fi.pk = false + fi.unique = false + default: + fi.dbcol = true + } + + switch fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + fi.rel = true + if fieldType == RelOneToOne { + fi.unique = true + } + case RelReverseMany, RelReverseOne: + fi.reverse = true + } + + if fi.rel && fi.dbcol { + switch onDelete { + case odCascade, odDoNothing: + case odSetDefault: + if !initial.Exist() { + err = errors.New("on_delete: set_default need set field a default value") + goto end + } + case odSetNULL: + if !fi.null { + err = errors.New("on_delete: set_null need set field null") + goto end + } + default: + if onDelete == "" { + onDelete = odCascade + } else { + err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) + goto end + } + } + + fi.onDelete = onDelete + } + + switch fieldType { + case TypeBooleanField: + case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: + if size != "" { + v, e := StrTo(size).Int32() + if e != nil { + err = fmt.Errorf("wrong size value `%s`", size) + } else { + fi.size = int(v) + } + } else { + fi.size = 255 + fi.toText = true + } + case TypeTextField: + fi.index = false + fi.unique = false + case TypeTimeField, TypeDateField, TypeDateTimeField: + if attrs["auto_now"] { + fi.autoNow = true + } else if attrs["auto_now_add"] { + fi.autoNowAdd = true + } + case TypeFloatField: + case TypeDecimalField: + d1 := digits + d2 := decimals + v1, er1 := StrTo(d1).Int8() + v2, er2 := StrTo(d2).Int8() + if er1 != nil || er2 != nil { + err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) + goto end + } + fi.digits = int(v1) + fi.decimals = int(v2) + default: + switch { + case fieldType&IsIntegerField > 0: + case fieldType&IsRelField > 0: + } + } + + if fieldType&IsIntegerField == 0 { + if fi.auto { + err = fmt.Errorf("non-integer type cannot set auto") + goto end + } + } + + if fi.auto || fi.pk { + if fi.auto { + switch addrField.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + default: + err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) + goto end + } + fi.pk = true + } + fi.null = false + fi.index = false + fi.unique = false + } + + if fi.unique { + fi.index = false + } + + // can not set default for these type + if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { + initial.Clear() + } + + if initial.Exist() { + v := initial + switch fieldType { + case TypeBooleanField: + _, err = v.Bool() + case TypeFloatField, TypeDecimalField: + _, err = v.Float64() + case TypeBitField: + _, err = v.Int8() + case TypeSmallIntegerField: + _, err = v.Int16() + case TypeIntegerField: + _, err = v.Int32() + case TypeBigIntegerField: + _, err = v.Int64() + case TypePositiveBitField: + _, err = v.Uint8() + case TypePositiveSmallIntegerField: + _, err = v.Uint16() + case TypePositiveIntegerField: + _, err = v.Uint32() + case TypePositiveBigIntegerField: + _, err = v.Uint64() + } + if err != nil { + tag, tagValue = "default", tags["default"] + goto wrongTag + } + } + + fi.initial = initial +end: + if err != nil { + return nil, err + } + return +wrongTag: + return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) +} diff --git a/pkg/orm/models_info_m.go b/pkg/orm/models_info_m.go new file mode 100644 index 00000000..a4d733b6 --- /dev/null +++ b/pkg/orm/models_info_m.go @@ -0,0 +1,148 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "os" + "reflect" +) + +// single model info +type modelInfo struct { + pkg string + name string + fullName string + table string + model interface{} + fields *fields + manual bool + addrField reflect.Value //store the original struct value + uniques []string + isThrough bool +} + +// new model info +func newModelInfo(val reflect.Value) (mi *modelInfo) { + mi = &modelInfo{} + mi.fields = newFields() + ind := reflect.Indirect(val) + mi.addrField = val + mi.name = ind.Type().Name() + mi.fullName = getFullName(ind.Type()) + addModelFields(mi, ind, "", []int{}) + return +} + +// index: FieldByIndex returns the nested field corresponding to index +func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { + var ( + err error + fi *fieldInfo + sf reflect.StructField + ) + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + // if the field is unexported skip + if sf.PkgPath != "" { + continue + } + // add anonymous struct fields + if sf.Anonymous { + addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) + continue + } + + fi, err = newFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { + break + } + //record current field index + fi.fieldIndex = append(fi.fieldIndex, index...) + fi.fieldIndex = append(fi.fieldIndex, i) + fi.mi = mi + fi.inModel = true + if !mi.fields.Add(fi) { + err = fmt.Errorf("duplicate column name: %s", fi.column) + break + } + if fi.pk { + if mi.fields.pk != nil { + err = fmt.Errorf("one model must have one pk field only") + break + } else { + mi.fields.pk = fi + } + } + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } +} + +// combine related model info to new model info. +// prepare for relation models query. +func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { + mi = new(modelInfo) + mi.fields = newFields() + mi.table = m1.table + "_" + m2.table + "s" + mi.name = camelString(mi.table) + mi.fullName = m1.pkg + "." + mi.name + + fa := new(fieldInfo) // pk + f1 := new(fieldInfo) // m1 table RelForeignKey + f2 := new(fieldInfo) // m2 table RelForeignKey + fa.fieldType = TypeBigIntegerField + fa.auto = true + fa.pk = true + fa.dbcol = true + fa.name = "Id" + fa.column = "id" + fa.fullName = mi.fullName + "." + fa.name + + f1.dbcol = true + f2.dbcol = true + f1.fieldType = RelForeignKey + f2.fieldType = RelForeignKey + f1.name = camelString(m1.table) + f2.name = camelString(m2.table) + f1.fullName = mi.fullName + "." + f1.name + f2.fullName = mi.fullName + "." + f2.name + f1.column = m1.table + "_id" + f2.column = m2.table + "_id" + f1.rel = true + f2.rel = true + f1.relTable = m1.table + f2.relTable = m2.table + f1.relModelInfo = m1 + f2.relModelInfo = m2 + f1.mi = mi + f2.mi = mi + + mi.fields.Add(fa) + mi.fields.Add(f1) + mi.fields.Add(f2) + mi.fields.pk = fa + + mi.uniques = []string{f1.column, f2.column} + return +} diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go new file mode 100644 index 00000000..e3a635f2 --- /dev/null +++ b/pkg/orm/models_test.go @@ -0,0 +1,497 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" +) + +// A slice string field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeVarCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(SliceStringField) + +// A json field. +type JSONFieldTest struct { + Name string + Data string +} + +func (e *JSONFieldTest) String() string { + data, _ := json.Marshal(e) + return string(data) +} + +func (e *JSONFieldTest) FieldType() int { + return TypeTextField +} + +func (e *JSONFieldTest) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + return json.Unmarshal([]byte(d), e) + default: + return fmt.Errorf(" unknown value `%v`", value) + } +} + +func (e *JSONFieldTest) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(JSONFieldTest) + +type Data struct { + ID int `orm:"column(id)"` + Boolean bool + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` + Date time.Time `orm:"type(date)"` + DateTime time.Time `orm:"column(datetime)"` + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + ID int `orm:"column(id)"` + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + +// only for mysql +type UserBig struct { + ID uint64 `orm:"column(id)"` + Name string +} + +type User struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool +} + +func (u *User) TableIndex() [][]string { + return [][]string{ + {"Id", "UserName"}, + {"Id", "Created"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + {"UserName", "Email"}, + } +} + +func NewUser() *User { + obj := new(User) + return obj +} + +type Profile struct { + ID int `orm:"column(id)"` + Age int16 + Money float64 + User *User `orm:"reverse(one)" json:"-"` + BestPost *Post `orm:"rel(one);null"` +} + +func (u *Profile) TableName() string { + return "user_profile" +} + +func NewProfile() *Profile { + obj := new(Profile) + return obj +} + +type Post struct { + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` +} + +func (u *Post) TableIndex() [][]string { + return [][]string{ + {"Id", "Created"}, + } +} + +func NewPost() *Post { + obj := new(Post) + return obj +} + +type Tag struct { + ID int `orm:"column(id)"` + Name string `orm:"size(30)"` + BestPost *Post `orm:"rel(one);null"` + Posts []*Post `orm:"reverse(many)" json:"-"` +} + +func NewTag() *Tag { + obj := new(Tag) + return obj +} + +type PostTags struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk)"` + Tag *Tag `orm:"rel(fk)"` +} + +func (m *PostTags) TableName() string { + return "prefix_post_tags" +} + +type Comment struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk);column(post)"` + Content string `orm:"type(text)"` + Parent *Comment `orm:"null;rel(fk)"` + Created time.Time `orm:"auto_now_add"` +} + +func NewComment() *Comment { + obj := new(Comment) + return obj +} + +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + +type ModelID struct { + ID int64 +} + +type ModelBase struct { + ModelID + + Created time.Time `orm:"auto_now_add;type(datetime)"` + Updated time.Time `orm:"auto_now;type(datetime)"` +} + +type InLine struct { + // Common Fields + ModelBase + + // Other Fields + Name string `orm:"unique"` + Email string +} + +func NewInLine() *InLine { + return new(InLine) +} + +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" +) + +var ( + dORM Ormer + dDbBaser dbBaser +) + +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + +func init() { + Debug, _ = StrTo(DBARGS.Debug).Bool() + + if DBARGS.Driver == "" || DBARGS.Source == "" { + fmt.Println(helpinfo) + os.Exit(2) + } + + RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DRMySQL { + alias.Engine = "INNODB" + } + +} diff --git a/pkg/orm/models_utils.go b/pkg/orm/models_utils.go new file mode 100644 index 00000000..71127a6b --- /dev/null +++ b/pkg/orm/models_utils.go @@ -0,0 +1,227 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +// 1 is attr +// 2 is tag +var supportTag = map[string]int{ + "-": 1, + "null": 1, + "index": 1, + "unique": 1, + "pk": 1, + "auto": 1, + "auto_now": 1, + "auto_now_add": 1, + "size": 2, + "column": 2, + "default": 2, + "rel": 2, + "reverse": 2, + "rel_table": 2, + "rel_through": 2, + "digits": 2, + "decimals": 2, + "on_delete": 2, + "type": 2, + "description": 2, +} + +// get reflect.Type name with package path. +func getFullName(typ reflect.Type) string { + return typ.PkgPath() + "." + typ.Name() +} + +// getTableName get struct table name. +// If the struct implement the TableName, then get the result as tablename +// else use the struct name which will apply snakeString. +func getTableName(val reflect.Value) string { + if fun := val.MethodByName("TableName"); fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + // has return and the first val is string + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return snakeString(reflect.Indirect(val).Type().Name()) +} + +// get table engine, myisam or innodb. +func getTableEngine(val reflect.Value) string { + fun := val.MethodByName("TableEngine") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].Kind() == reflect.String { + return vals[0].String() + } + } + return "" +} + +// get table index from method. +func getTableIndex(val reflect.Value) [][]string { + fun := val.MethodByName("TableIndex") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get table unique from method +func getTableUnique(val reflect.Value) [][]string { + fun := val.MethodByName("TableUnique") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 && vals[0].CanInterface() { + if d, ok := vals[0].Interface().([][]string); ok { + return d + } + } + } + return nil +} + +// get snaked column name +func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { + column := col + if col == "" { + column = nameStrategyMap[nameStrategy](sf.Name) + } + switch ft { + case RelForeignKey, RelOneToOne: + if len(col) == 0 { + column = column + "_id" + } + case RelManyToMany, RelReverseMany, RelReverseOne: + column = sf.Name + } + return column +} + +// return field type as type constant from reflect.Value +func getFieldType(val reflect.Value) (ft int, err error) { + switch val.Type() { + case reflect.TypeOf(new(int8)): + ft = TypeBitField + case reflect.TypeOf(new(int16)): + ft = TypeSmallIntegerField + case reflect.TypeOf(new(int32)), + reflect.TypeOf(new(int)): + ft = TypeIntegerField + case reflect.TypeOf(new(int64)): + ft = TypeBigIntegerField + case reflect.TypeOf(new(uint8)): + ft = TypePositiveBitField + case reflect.TypeOf(new(uint16)): + ft = TypePositiveSmallIntegerField + case reflect.TypeOf(new(uint32)), + reflect.TypeOf(new(uint)): + ft = TypePositiveIntegerField + case reflect.TypeOf(new(uint64)): + ft = TypePositiveBigIntegerField + case reflect.TypeOf(new(float32)), + reflect.TypeOf(new(float64)): + ft = TypeFloatField + case reflect.TypeOf(new(bool)): + ft = TypeBooleanField + case reflect.TypeOf(new(string)): + ft = TypeVarCharField + case reflect.TypeOf(new(time.Time)): + ft = TypeDateTimeField + default: + elm := reflect.Indirect(val) + switch elm.Kind() { + case reflect.Int8: + ft = TypeBitField + case reflect.Int16: + ft = TypeSmallIntegerField + case reflect.Int32, reflect.Int: + ft = TypeIntegerField + case reflect.Int64: + ft = TypeBigIntegerField + case reflect.Uint8: + ft = TypePositiveBitField + case reflect.Uint16: + ft = TypePositiveSmallIntegerField + case reflect.Uint32, reflect.Uint: + ft = TypePositiveIntegerField + case reflect.Uint64: + ft = TypePositiveBigIntegerField + case reflect.Float32, reflect.Float64: + ft = TypeFloatField + case reflect.Bool: + ft = TypeBooleanField + case reflect.String: + ft = TypeVarCharField + default: + if elm.Interface() == nil { + panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) + } + switch elm.Interface().(type) { + case sql.NullInt64: + ft = TypeBigIntegerField + case sql.NullFloat64: + ft = TypeFloatField + case sql.NullBool: + ft = TypeBooleanField + case sql.NullString: + ft = TypeVarCharField + case time.Time: + ft = TypeDateTimeField + } + } + } + if ft&IsFieldType == 0 { + err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) + } + return +} + +// parse struct tag string +func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { + attrs = make(map[string]bool) + tags = make(map[string]string) + for _, v := range strings.Split(data, defaultStructTagDelim) { + if v == "" { + continue + } + v = strings.TrimSpace(v) + if t := strings.ToLower(v); supportTag[t] == 1 { + attrs[t] = true + } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { + name := t[:i] + if supportTag[name] == 2 { + v = v[i+1 : len(v)-1] + tags[name] = v + } + } else { + DebugLog.Println("unsupport orm tag", v) + } + } + return +} diff --git a/pkg/orm/orm.go b/pkg/orm/orm.go new file mode 100644 index 00000000..0551b1cd --- /dev/null +++ b/pkg/orm/orm.go @@ -0,0 +1,579 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "reflect" + "sync" + "time" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = false + DebugLog = NewLog(os.Stdout) + DefaultRowsLimit = -1 + DefaultRelsDepth = 2 + DefaultTimeLoc = time.Local + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +// Params stores the Params +type Params map[string]interface{} + +// ParamsList stores paramslist +type ParamsList []interface{} + +type orm struct { + alias *alias + db dbQuerier + isTx bool +} + +var _ Ormer = new(orm) + +// get model info and model reflect value +func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if needPtr && val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) + } + name := getFullName(typ) + if mi, ok := modelCache.getByFullName(name); ok { + return mi, ind + } + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) +} + +// get field info from model info by given field name +func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { + fi, ok := mi.fields.GetByAny(name) + if !ok { + panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) + } + return fi +} + +// read data to model +func (o *orm) Read(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + cols = append([]string{col1}, cols...) + mi, ind := o.getMiInd(md, true) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + if err == ErrNoRows { + // Create + id, err := o.Insert(md) + return (err == nil), id, err + } + + id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else if mi.fields.pk.rel { + return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) + } else { + id = vid.Int() + } + + return false, id, err +} + +// insert model data to database +func (o *orm) Insert(md interface{}) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// set auto pk field +func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) + } + } +} + +// insert some models to database +func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := reflect.Indirect(sind.Index(i)) + mi, _ := o.getMiInd(ind.Interface(), false) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt++ + } + } else { + mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + } + return cnt, nil +} + +// InsertOrUpdate data to database +func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// update model to database. +// cols set the columns those want to update. +func (o *orm) Update(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { + mi, ind := o.getMiInd(md, true) + num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + if err != nil { + return num, err + } + if num > 0 { + o.setPk(mi, ind, 0) + } + return num, nil +} + +// create a models to models queryer +func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + switch { + case fi.fieldType == RelManyToMany: + case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: + default: + panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) + } + + return newQueryM2M(md, o, mi, fi, ind) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + _, fi, ind, qseter := o.queryRelated(md, name) + + qs := qseter.(*querySet) + + var relDepth int + var limit, offset int64 + var order string + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + relDepth = DefaultRelsDepth + } + } else if v, ok := arg.(int); ok { + relDepth = v + } + case 1: + limit = ToInt64(arg) + case 2: + offset = ToInt64(arg) + case 3: + order, _ = arg.(string) + } + } + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + limit = 1 + offset = 0 + } + + qs.limit = limit + qs.offset = offset + qs.relDepth = relDepth + + if len(order) > 0 { + qs.orders = []string{order} + } + + find := ind.FieldByIndex(fi.fieldIndex) + + var nums int64 + var err error + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + val := reflect.New(find.Type().Elem()) + container := val.Interface() + err = qs.One(container) + if err == nil { + find.Set(val) + nums = 1 + } + default: + nums, err = qs.All(find.Addr().Interface()) + } + + return nums, err +} + +// return a QuerySeter for related models to md model. +// it can do all, update, delete in QuerySeter. +// example: +// qs := orm.QueryRelated(post,"Tag") +// qs.All(&[]*Tag{}) +// +func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { + // is this api needed ? + _, _, _, qs := o.queryRelated(md, name) + return qs +} + +// get QuerySeter for related models to md model +func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { + mi, ind := o.getMiInd(md, true) + fi := o.getFieldInfo(mi, name) + + _, _, exist := getExistPk(mi, ind) + if !exist { + panic(ErrMissPK) + } + + var qs *querySet + + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + if !fi.inModel { + break + } + qs = o.getRelQs(md, mi, fi) + case RelReverseOne, RelReverseMany: + if !fi.inModel { + break + } + qs = o.getReverseQs(md, mi, fi) + } + + if qs == nil { + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) + } + + return mi, fi, ind, qs +} + +// get reverse relation QuerySeter +func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelReverseOne, RelReverseMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) + } + + var q *querySet + + if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { + q = newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) + q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// get relation QuerySeter +func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { + switch fi.fieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) + } + + q := newQuerySet(o, fi.relModelInfo).(*querySet) + q.cond = NewCondition() + + if fi.fieldType == RelManyToMany { + q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) + } else { + q.cond = q.cond.And(fi.reverseFieldInfo.column, md) + } + + return q +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + var name string + if table, ok := ptrStructOrTableName.(string); ok { + name = nameStrategyMap[defaultNameStrategy](table) + if mi, ok := modelCache.get(name); ok { + qs = newQuerySet(o, mi) + } + } else { + name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) + if mi, ok := modelCache.getByFullName(name); ok { + qs = newQuerySet(o, mi) + } + } + if qs == nil { + panic(fmt.Errorf(" table name: `%s` not exists", name)) + } + return +} + +// switch to another registered database driver by given name. +func (o *orm) Using(name string) error { + if o.isTx { + panic(fmt.Errorf(" transaction has been start, cannot change db")) + } + if al, ok := dataBaseCache.get(name); ok { + o.alias = al + if Debug { + o.db = newDbQueryLog(al, al.DB) + } else { + o.db = al.DB + } + } else { + return fmt.Errorf(" unknown db alias name `%s`", name) + } + return nil +} + +// begin transaction +func (o *orm) Begin() error { + return o.BeginTx(context.Background(), nil) +} + +func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + var tx *sql.Tx + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return err + } + o.isTx = true + if Debug { + o.db.(*dbQueryLog).SetDB(tx) + } else { + o.db = tx + } + return nil +} + +// commit transaction +func (o *orm) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Commit() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *orm) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.db.(txEnder).Rollback() + if err == nil { + o.isTx = false + o.Using(o.alias.Name) + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *orm) Raw(query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +// return current using database Driver +func (o *orm) Driver() Driver { + return driver(o.alias.Name) +} + +// return sql.DBStats for current database +func (o *orm) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats + } + return nil +} + +// NewOrm create new orm +func NewOrm() Ormer { + BootStrap() // execute only once + + o := new(orm) + err := o.Using("default") + if err != nil { + panic(err) + } + return o +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + var al *alias + + if dr, ok := drivers[driverName]; ok { + al = new(alias) + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + detectTZ(al) + + o := new(orm) + o.alias = al + + if Debug { + o.db = newDbQueryLog(o.alias, db) + } else { + o.db = db + } + + return o, nil +} diff --git a/pkg/orm/orm_conds.go b/pkg/orm/orm_conds.go new file mode 100644 index 00000000..f3fd66f0 --- /dev/null +++ b/pkg/orm/orm_conds.go @@ -0,0 +1,153 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +type condValue struct { + exprs []string + args []interface{} + cond *Condition + isOr bool + isNot bool + isCond bool + isRaw bool + sql string +} + +// Condition struct. +// work for WHERE conditions. +type Condition struct { + params []condValue +} + +// NewCondition return new condition struct +func NewCondition() *Condition { + c := &Condition{} + return c +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + if len(sql) == 0 { + panic(fmt.Errorf(" sql cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) + return &c +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) + return &c +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) + return &c +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true}) + } + return c +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) + } + return c +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) + return &c +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(fmt.Errorf(" args cannot empty")) + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) + return &c +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) + } + return c +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) + } + return c +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return len(c.params) == 0 +} + +// clone clone a condition +func (c Condition) clone() *Condition { + return &c +} diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go new file mode 100644 index 00000000..f107bb59 --- /dev/null +++ b/pkg/orm/orm_log.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + "io" + "log" + "strings" + "time" +) + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +//costomer log func +var LogFunc func(query map[string]interface{}) + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + d := new(Log) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) + return d +} + +func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { + var logMap = make(map[string]interface{}) + sub := time.Now().Sub(t) / 1e5 + elsp := float64(int(sub)) / 10.0 + logMap["cost_time"] = elsp + flag := " OK" + if err != nil { + flag = "FAIL" + } + logMap["flag"] = flag + con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) + cons := make([]string, 0, len(args)) + for _, arg := range args { + cons = append(cons, fmt.Sprintf("%v", arg)) + } + if len(cons) > 0 { + con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) + } + if err != nil { + con += " - " + err.Error() + } + logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) + if LogFunc != nil{ + LogFunc(logMap) + } + DebugLog.Println(con) +} + +// statement query logger struct. +// if dev mode, use stmtQueryLog, or use stmtQuerier. +type stmtQueryLog struct { + alias *alias + query string + stmt stmtQuerier +} + +var _ stmtQuerier = new(stmtQueryLog) + +func (d *stmtQueryLog) Close() error { + a := time.Now() + err := d.stmt.Close() + debugLogQueies(d.alias, "st.Close", d.query, a, err) + return err +} + +func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.stmt.Exec(args...) + debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.stmt.Query(args...) + debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) + return res, err +} + +func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + a := time.Now() + res := d.stmt.QueryRow(args...) + debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) + return res +} + +func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { + d := new(stmtQueryLog) + d.stmt = stmt + d.alias = alias + d.query = query + return d +} + +// database query logger struct. +// if dev mode, use dbQueryLog, or use dbQuerier. +type dbQueryLog struct { + alias *alias + db dbQuerier + tx txer + txe txEnder +} + +var _ dbQuerier = new(dbQueryLog) +var _ txer = new(dbQueryLog) +var _ txEnder = new(dbQueryLog) + +func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.Prepare(query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + a := time.Now() + stmt, err := d.db.PrepareContext(ctx, query) + debugLogQueies(d.alias, "db.Prepare", query, a, err) + return stmt, err +} + +func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.Exec(query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + a := time.Now() + res, err := d.db.ExecContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Exec", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.Query(query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + a := time.Now() + res, err := d.db.QueryContext(ctx, query, args...) + debugLogQueies(d.alias, "db.Query", query, a, err, args...) + return res, err +} + +func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRow(query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + a := time.Now() + res := d.db.QueryRowContext(ctx, query, args...) + debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) + return res +} + +func (d *dbQueryLog) Begin() (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).Begin() + debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).BeginTx(ctx, opts) + debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) + return tx, err +} + +func (d *dbQueryLog) Commit() error { + a := time.Now() + err := d.db.(txEnder).Commit() + debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) + return err +} + +func (d *dbQueryLog) Rollback() error { + a := time.Now() + err := d.db.(txEnder).Rollback() + debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) + return err +} + +func (d *dbQueryLog) SetDB(db dbQuerier) { + d.db = db +} + +func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { + d := new(dbQueryLog) + d.alias = alias + d.db = db + return d +} diff --git a/pkg/orm/orm_object.go b/pkg/orm/orm_object.go new file mode 100644 index 00000000..de3181ce --- /dev/null +++ b/pkg/orm/orm_object.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" +) + +// an insert queryer struct +type insertSet struct { + mi *modelInfo + orm *orm + stmt stmtQuerier + closed bool +} + +var _ Inserter = new(insertSet) + +// insert model ignore it's registered or not. +func (o *insertSet) Insert(md interface{}) (int64, error) { + if o.closed { + return 0, ErrStmtClosed + } + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := getFullName(typ) + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) + } + if name != o.mi.fullName { + panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) + } + id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + if err != nil { + return id, err + } + if id > 0 { + if o.mi.fields.pk.auto { + if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) + } + } + } + return id, nil +} + +// close insert queryer statement +func (o *insertSet) Close() error { + if o.closed { + return ErrStmtClosed + } + o.closed = true + return o.stmt.Close() +} + +// create new insert queryer. +func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { + bi := new(insertSet) + bi.orm = orm + bi.mi = mi + st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + if err != nil { + return nil, err + } + if Debug { + bi.stmt = newStmtQueryLog(orm.alias, st, query) + } else { + bi.stmt = st + } + return bi, nil +} diff --git a/pkg/orm/orm_querym2m.go b/pkg/orm/orm_querym2m.go new file mode 100644 index 00000000..6a270a0d --- /dev/null +++ b/pkg/orm/orm_querym2m.go @@ -0,0 +1,140 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "reflect" + +// model to model struct +type queryM2M struct { + md interface{} + mi *modelInfo + fi *fieldInfo + qs *querySet + ind reflect.Value +} + +// add models to origin models when creating queryM2M. +// example: +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} +// +// make sure the relation is defined in post model struct tag. +func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + fi := o.fi + mi := fi.relThroughModelInfo + mfi := fi.reverseFieldInfo + rfi := fi.reverseFieldInfoTwo + + orm := o.qs.orm + dbase := orm.alias.DbBaser + + var models []interface{} + var otherValues []interface{} + var otherNames []string + + for _, colname := range mi.fields.dbcols { + if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && + mi.fields.columns[colname] != mi.fields.pk { + otherNames = append(otherNames, colname) + } + } + for i, md := range mds { + if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { + otherValues = append(otherValues, md) + mds = append(mds[:i], mds[i+1:]...) + } + } + for _, md := range mds { + val := reflect.ValueOf(md) + if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { + for i := 0; i < val.Len(); i++ { + v := val.Index(i) + if v.CanInterface() { + models = append(models, v.Interface()) + } + } + } else { + models = append(models, md) + } + } + + _, v1, exist := getExistPk(o.mi, o.ind) + if !exist { + panic(ErrMissPK) + } + + names := []string{mfi.column, rfi.column} + + values := make([]interface{}, 0, len(models)*2) + for _, md := range models { + + ind := reflect.Indirect(reflect.ValueOf(md)) + var v2 interface{} + if ind.Kind() != reflect.Struct { + v2 = ind.Interface() + } else { + _, v2, exist = getExistPk(fi.relModelInfo, ind) + if !exist { + panic(ErrMissPK) + } + } + values = append(values, v1, v2) + + } + names = append(names, otherNames...) + values = append(values, otherValues...) + return dbase.InsertValue(orm.db, mi, true, names, values) +} + +// remove models following the origin model relationship +func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + fi := o.fi + qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) + + return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() +} + +// check model is existed in relationship of origin model +func (o *queryM2M) Exist(md interface{}) bool { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md). + Filter(fi.reverseFieldInfoTwo.name, md).Exist() +} + +// clean all models in related of origin model +func (o *queryM2M) Clear() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() +} + +// count all related models of origin model +func (o *queryM2M) Count() (int64, error) { + fi := o.fi + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() +} + +var _ QueryM2Mer = new(queryM2M) + +// create new M2M queryer. +func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { + qm2m := new(queryM2M) + qm2m.md = md + qm2m.mi = mi + qm2m.fi = fi + qm2m.ind = ind + qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + return qm2m +} diff --git a/pkg/orm/orm_queryset.go b/pkg/orm/orm_queryset.go new file mode 100644 index 00000000..878b836b --- /dev/null +++ b/pkg/orm/orm_queryset.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "fmt" +) + +type colValue struct { + value int64 + opt operator +} + +type operator int + +// define Col operations +const ( + ColAdd operator = iota + ColMinus + ColMultiply + ColExcept + ColBitAnd + ColBitRShift + ColBitLShift + ColBitXOR + ColBitOr +) + +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } +func ColValue(opt operator, value interface{}) interface{} { + switch opt { + case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, + ColBitLShift, ColBitXOR, ColBitOr: + default: + panic(fmt.Errorf("orm.ColValue wrong operator")) + } + v, err := StrTo(ToStr(value)).Int64() + if err != nil { + panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) + } + var val colValue + val.value = v + val.opt = opt + return val +} + +// real query struct +type querySet struct { + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []string + distinct bool + forupdate bool + orm *orm + ctx context.Context + forContext bool +} + +var _ QuerySeter = new(querySet) + +// add condition expression to QuerySeter. +func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.And(expr, args...) + return &o +} + +// add raw sql to querySeter. +func (o querySet) FilterRaw(expr string, sql string) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.Raw(expr, sql) + return &o +} + +// add NOT condition to querySeter. +func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.AndNot(expr, args...) + return &o +} + +// set offset number +func (o *querySet) setOffset(num interface{}) { + o.offset = ToInt64(num) +} + +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. +func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { + o.limit = ToInt64(limit) + if len(args) > 0 { + o.setOffset(args[0]) + } + return &o +} + +// add OFFSET value +func (o querySet) Offset(offset interface{}) QuerySeter { + o.setOffset(offset) + return &o +} + +// add GROUP expression +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + +// add ORDER expression. +// "column" means ASC, "-column" means DESC. +func (o querySet) OrderBy(exprs ...string) QuerySeter { + o.orders = exprs + return &o +} + +// add DISTINCT to SELECT +func (o querySet) Distinct() QuerySeter { + o.distinct = true + return &o +} + +// add FOR UPDATE to SELECT +func (o querySet) ForUpdate() QuerySeter { + o.forupdate = true + return &o +} + +// set relation model to query together. +// it will query relation models and assign to parent model. +func (o querySet) RelatedSel(params ...interface{}) QuerySeter { + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + o.related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Errorf(" wrong param kind: %v", val)) + } + } + } + return &o +} + +// set condition to QuerySeter. +func (o querySet) SetCond(cond *Condition) QuerySeter { + o.cond = cond + return &o +} + +// get condition from QuerySeter +func (o querySet) GetCond() *Condition { + return o.cond +} + +// return QuerySeter execution result number +func (o *querySet) Count() (int64, error) { + return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// check result empty or not after QuerySeter executed +func (o *querySet) Exist() bool { + cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return cnt > 0 +} + +// execute update with parameters +func (o *querySet) Update(values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) +} + +// execute delete +func (o *querySet) Delete() (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) +} + +// return a insert queryer. +// it can be used in times. +// example: +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) +func (o *querySet) PrepareInsert() (Inserter, error) { + return newInsertSet(o.orm, o.mi) +} + +// query all data and map to containers. +// cols means the columns when querying. +func (o *querySet) All(container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) +} + +// query one row data and map to containers. +// cols means the columns when querying. +func (o *querySet) One(container interface{}, cols ...string) error { + o.limit = 1 + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + if err != nil { + return err + } + if num == 0 { + return ErrNoRows + } + + if num > 1 { + return ErrMultiRows + } + return nil +} + +// query all data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. +func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to [][]interface +// it converts data to [][column_index]value +func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) +} + +// query all data and map to []interface. +// it's designed for one row record set, auto change to []value, not [][column]value. +func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// set context to QuerySeter. +func (o querySet) WithContext(ctx context.Context) QuerySeter { + o.ctx = ctx + o.forContext = true + return &o +} + +// create new QuerySeter. +func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { + o := new(querySet) + o.mi = mi + o.orm = orm + return o +} diff --git a/pkg/orm/orm_raw.go b/pkg/orm/orm_raw.go new file mode 100644 index 00000000..3325a7ea --- /dev/null +++ b/pkg/orm/orm_raw.go @@ -0,0 +1,867 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" +) + +// raw sql string prepared statement +type rawPrepare struct { + rs *rawSet + stmt stmtQuerier + closed bool +} + +func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { + if o.closed { + return nil, ErrStmtClosed + } + return o.stmt.Exec(args...) +} + +func (o *rawPrepare) Close() error { + o.closed = true + return o.stmt.Close() +} + +func newRawPreparer(rs *rawSet) (RawPreparer, error) { + o := new(rawPrepare) + o.rs = rs + + query := rs.query + rs.orm.alias.DbBaser.ReplaceMarks(&query) + + st, err := rs.orm.db.Prepare(query) + if err != nil { + return nil, err + } + if Debug { + o.stmt = newStmtQueryLog(rs.orm.alias, st, query) + } else { + o.stmt = st + } + return o, nil +} + +// raw query seter +type rawSet struct { + query string + args []interface{} + orm *orm +} + +var _ RawSeter = new(rawSet) + +// set args for every query +func (o rawSet) SetArgs(args ...interface{}) RawSeter { + o.args = args + return &o +} + +// execute raw sql and return sql.Result +func (o *rawSet) Exec() (sql.Result, error) { + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + return o.orm.db.Exec(query, args...) +} + +// set field value to row container +func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { + switch ind.Kind() { + case reflect.Bool: + if value == nil { + ind.SetBool(false) + } else if v, ok := value.(bool); ok { + ind.SetBool(v) + } else { + v, _ := StrTo(ToStr(value)).Bool() + ind.SetBool(v) + } + + case reflect.String: + if value == nil { + ind.SetString("") + } else { + ind.SetString(ToStr(value)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == nil { + ind.SetInt(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetInt(val.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetInt(int64(val.Uint())) + default: + v, _ := StrTo(ToStr(value)).Int64() + ind.SetInt(v) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == nil { + ind.SetUint(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ind.SetUint(uint64(val.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ind.SetUint(val.Uint()) + default: + v, _ := StrTo(ToStr(value)).Uint64() + ind.SetUint(v) + } + } + case reflect.Float64, reflect.Float32: + if value == nil { + ind.SetFloat(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Float64: + ind.SetFloat(val.Float()) + default: + v, _ := StrTo(ToStr(value)).Float64() + ind.SetFloat(v) + } + } + + case reflect.Struct: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + return + } + switch ind.Interface().(type) { + case time.Time: + var str string + switch d := value.(type) { + case time.Time: + o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) + ind.Set(reflect.ValueOf(d)) + case []byte: + str = string(d) + case string: + str = d + } + if str != "" { + if len(str) >= 19 { + str = str[:19] + t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) + if err == nil { + t = t.In(DefaultTimeLoc) + ind.Set(reflect.ValueOf(t)) + } + } else if len(str) >= 10 { + str = str[:10] + t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } + } + } + case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: + indi := reflect.New(ind.Type()).Interface() + sc, ok := indi.(sql.Scanner) + if !ok { + return + } + err := sc.Scan(value) + if err == nil { + ind.Set(reflect.Indirect(reflect.ValueOf(sc))) + } + } + + case reflect.Ptr: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + break + } + ind.Set(reflect.New(ind.Type().Elem())) + o.setFieldValue(reflect.Indirect(ind), value) + } +} + +// set field value in loop for slice container +func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { + nInds := *nIndsPtr + + cur := 0 + for i := 0; i < len(sInds); i++ { + sInd := sInds[i] + eTyp := eTyps[i] + + typ := eTyp + isPtr := false + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + if typ.Kind() == reflect.Ptr { + isPtr = true + typ = typ.Elem() + } + + var nInd reflect.Value + if init { + nInd = reflect.New(sInd.Type()).Elem() + } else { + nInd = nInds[i] + } + + val := reflect.New(typ) + ind := val.Elem() + + tpName := ind.Type().String() + + if ind.Kind() == reflect.Struct { + if tpName == "time.Time" { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + } else { + value := reflect.ValueOf(refs[cur]).Elem().Interface() + if isPtr && value == nil { + val = reflect.New(val.Type()).Elem() + } else { + o.setFieldValue(ind, value) + } + cur++ + } + + if nInd.Kind() == reflect.Slice { + if isPtr { + nInd = reflect.Append(nInd, val) + } else { + nInd = reflect.Append(nInd, ind) + } + } else { + if isPtr { + nInd.Set(val) + } else { + nInd.Set(ind) + } + } + + nInds[i] = nInd + } +} + +// query data and map to container +func (o *rawSet) QueryRow(containers ...interface{}) error { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" all args must be use ptr")) + } + + etyp := ind.Type() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, ind) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } + return err + } + + defer rows.Close() + + if rows.Next() { + if structMode { + columns, err := rows.Columns() + if err != nil { + return err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return err + } + + ind := sInds[0] + + if ind.Kind() == reflect.Ptr { + if ind.IsNil() || !ind.IsValid() { + ind.Set(reflect.New(eTyps[0].Elem())) + } + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + } else { + if err := rows.Scan(refs...); err != nil { + return err + } + + nInds := make([]reflect.Value, len(sInds)) + o.loopSetRefs(refs, sInds, &nInds, eTyps, true) + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + + } else { + return ErrNoRows + } + + return nil +} + +// query data rows and map to container +func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) + structMode := false + for _, container := range containers { + val := reflect.ValueOf(container) + sInd := reflect.Indirect(val) + if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { + panic(fmt.Errorf(" all args must be use ptr slice")) + } + + etyp := sInd.Type().Elem() + typ := etyp + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + sInds = append(sInds, sInd) + eTyps = append(eTyps, etyp) + + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFullName(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + rows, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rows.Close() + + var cnt int64 + nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + + for rows.Next() { + + if structMode { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 && !sInd.IsNil() { + sInd.Set(reflect.New(sInd.Type()).Elem()) + } + + var ind reflect.Value + if eTyps[0].Kind() == reflect.Ptr { + ind = reflect.New(eTyps[0].Elem()) + } else { + ind = reflect.New(eTyps[0]) + } + + if ind.Kind() == reflect.Ptr { + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + field := ind.FieldByIndex(fi.fieldIndex) + if fi.fieldType&IsRelField > 0 { + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field.Set(mf) + field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + } + o.setFieldValue(field, value) + } + } + } else { + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + var col string + if col = tags["column"]; col == "" { + col = nameStrategyMap[nameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + // init call the recursive function + recursiveSetField(ind) + } + + if eTyps[0].Kind() == reflect.Ptr { + ind = ind.Addr() + } + + sInd = reflect.Append(sInd, ind) + + } else { + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) + } + + cnt++ + } + + if cnt > 0 { + + if structMode { + sInds[0].Set(sInd) + } else { + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + } + + return cnt, nil +} + +func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch container.(type) { + case *[]Params: + typ = 1 + case *[]ParamsList: + typ = 2 + case *ParamsList: + typ = 3 + default: + panic(fmt.Errorf(" unsupport read values type `%T`", container)) + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + var rs *sql.Rows + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + indexs []int + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + if len(needCols) > 0 { + indexs = make([]int, 0, len(needCols)) + } else { + indexs = make([]int, 0, len(columns)) + } + + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + var ref sql.NullString + refs[i] = &ref + + if len(needCols) > 0 { + for _, c := range needCols { + if c == cols[i] { + indexs = append(indexs, i) + } + } + } else { + indexs = append(indexs, i) + } + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params[cols[i]] = value.String + } else { + params[cols[i]] = nil + } + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + params = append(params, value.String) + } else { + params = append(params, nil) + } + } + lists = append(lists, params) + case 3: + for _, i := range indexs { + ref := refs[i] + value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) + if value.Valid { + list = append(list, value.String) + } else { + list = append(list, nil) + } + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { + var ( + maps Params + ind *reflect.Value + ) + + var typ int + switch container.(type) { + case *Params: + typ = 1 + default: + typ = 2 + vl := reflect.ValueOf(container) + id := reflect.Indirect(vl) + if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { + panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) + } + + ind = &id + } + + query := o.query + o.orm.alias.DbBaser.ReplaceMarks(&query) + + args := getFlatParams(nil, o.args, o.orm.alias.TZ) + + rs, err := o.orm.db.Query(query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + var ( + refs []interface{} + cnt int64 + cols []string + ) + + var ( + keyIndex = -1 + valueIndex = -1 + ) + + for rs.Next() { + if cnt == 0 { + columns, err := rs.Columns() + if err != nil { + return 0, err + } + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + if keyCol == cols[i] { + keyIndex = i + } + if typ == 1 || keyIndex == i { + var ref sql.NullString + refs[i] = &ref + } else { + var ref interface{} + refs[i] = &ref + } + if valueCol == cols[i] { + valueIndex = i + } + } + if keyIndex == -1 || valueIndex == -1 { + panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 { + switch typ { + case 1: + maps = make(Params) + } + } + + key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String + + switch typ { + case 1: + value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) + if value.Valid { + maps[key] = value.String + } else { + maps[key] = nil + } + + default: + if id := ind.FieldByName(camelString(key)); id.IsValid() { + o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) + } + } + + cnt++ + } + + if typ == 1 { + v, _ := container.(*Params) + *v = maps + } + + return cnt, nil +} + +// query data to []map[string]interface +func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to [][]interface +func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query data to []interface +func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { + return o.readValues(container, cols) +} + +// query all rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(result, keyCol, valueCol) +} + +// query all rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// to struct { +// Total int +// Found int +// } +func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return o.queryRowsTo(ptrStruct, keyCol, valueCol) +} + +// return prepared raw statement for used in times. +func (o *rawSet) Prepare() (RawPreparer, error) { + return newRawPreparer(o) +} + +func newRawSet(orm *orm, query string, args []interface{}) RawSeter { + o := new(rawSet) + o.query = query + o.args = args + o.orm = orm + return o +} diff --git a/pkg/orm/orm_test.go b/pkg/orm/orm_test.go new file mode 100644 index 00000000..bdb430b6 --- /dev/null +++ b/pkg/orm/orm_test.go @@ -0,0 +1,2494 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package orm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" +) + +var _ = os.PathSeparator + +var ( + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" +) + +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { + if len(args) == 0 { + return false, fmt.Errorf("miss args") + } + b := args[0] + arg := argAny(args) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b + } else { + err = fmt.Errorf("compare datetime miss format") + goto wrongArg + } + } + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } else { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } + } + +wrongArg: + if err != nil { + return false, err + } + + return true, nil +} + +func AssertIs(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(true, a, args...); !ok { + return err + } + return nil +} + +func AssertNot(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(false, a, args...); !ok { + return err + } + return nil +} + +func getCaller(skip int) string { + pc, file, line, _ := runtime.Caller(skip) + fun := runtime.FuncForPC(pc) + _, fn := filepath.Split(file) + data, err := ioutil.ReadFile(file) + var codes []string + if err == nil { + lines := bytes.Split(data, []byte{'\n'}) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + if code != "" { + codes = append(codes, code) + } + } + } + funName := fun.Name() + if i := strings.LastIndex(funName, "."); i > -1 { + funName = funName[i+1:] + } + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) +} + +func throwFail(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.Fail() + } +} + +func throwFailNow(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.FailNow() + } +} + +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + +func TestSyncDb(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + err := RunSyncdb("default", true, Debug) + throwFail(t, err) + + modelCache.clean() +} + +func TestRegisterModels(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + BootStrap() + + dORM = NewOrm() + dDbBaser = getDbAlias("default").DbBaser +} + +func TestModelSyntax(t *testing.T) { + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFullName(fn) + throwFail(t, AssertIs(ok, true)) + + mi, ok = modelCache.get("user") + throwFail(t, AssertIs(ok, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + } +} + +var DataValues = map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), +} + +func TestDataTypes(t *testing.T) { + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + if name == "JSON" { + continue + } + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = Data{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + d = DataNull{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.JSON, data)) + + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{ID: 2} + err = dORM.Read(&d) + throwFail(t, err) + + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{String: "test", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{ID: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) + throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) + throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) +} + +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestCRUD(t *testing.T) { + profile := NewProfile() + profile.Age = 30 + profile.Money = 1234.12 + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + u := &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + user.UserName = "astaxie" + user.Profile = profile + num, err := dORM.Update(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) + + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) + + num, err = dORM.Delete(profile) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + throwFail(t, AssertIs(true, u.Profile == nil)) + + num, err = dORM.Delete(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: 100} + err = dORM.Read(u) + throwFail(t, AssertIs(err, ErrNoRows)) + + ub := UserBig{} + ub.Name = "name" + id, err = dORM.Insert(&ub) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + ub = UserBig{ID: 1} + err = dORM.Read(&ub) + throwFail(t, err) + throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertTestData(t *testing.T) { + var users []*User + + profile := NewProfile() + profile.Age = 28 + profile.Money = 1234.12 + + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 1 + user.IsStaff = false + user.IsActive = true + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + profile = NewProfile() + profile.Age = 30 + profile.Money = 4321.09 + + id, err = dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "astaxie" + user.Email = "astaxie@gmail.com" + user.Password = "password" + user.Status = 2 + user.IsStaff = true + user.IsActive = false + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "nobody" + user.Email = "nobody@gmail.com" + user.Password = "nobody" + user.Status = 3 + user.IsStaff = false + user.IsActive = false + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 4)) + + tags := []*Tag{ + {Name: "golang", BestPost: &Post{ID: 2}}, + {Name: "example"}, + {Name: "format"}, + {Name: "c++"}, + } + + posts := []*Post{ + {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. +With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, + {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. +The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, + } + + comments := []*Comment{ + {Post: posts[0], Content: "a comment"}, + {Post: posts[1], Content: "yes"}, + {Post: posts[1]}, + {Post: posts[1]}, + {Post: posts[2]}, + {Post: posts[2]}, + } + + for _, tag := range tags { + id, err := dORM.Insert(tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, post := range posts { + id, err := dORM.Insert(post) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(post.Tags) + if num > 0 { + nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + + for _, comment := range comments { + id, err := dORM.Insert(comment) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + +} + +func TestCustomField(t *testing.T) { + user := User{ID: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + user.Extra.Name = "beego" + user.Extra.Data = "orm" + _, err = dORM.Update(&user, "Langs", "Extra") + throwFailNow(t, err) + + user = User{ID: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) + + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) +} + +func TestExpr(t *testing.T) { + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable((*User)(nil)) + qs = dORM.QueryTable("User") + qs = dORM.QueryTable("user") + num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + // throwFail(t, err) + // throwFail(t, AssertIs(num, 3)) +} + +func TestOperators(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__iexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__contains", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + var shouldNum int + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__contains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gt", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gte", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("status__lt", Uint(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__lte", Int(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("user_name__startswith", "s").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsSqlite || IsTidb { + shouldNum = 1 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__startswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__istartswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__endswith", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__endswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__iendswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("status__in", 1, 2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSetCond(t *testing.T) { + cond := NewCondition() + cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + + qs := dORM.QueryTable("user") + num, err := qs.SetCond(cond1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) +} + +func TestLimit(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Limit(-1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + num, err = qs.Limit(-1, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Limit(0, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOffset(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOrderBy(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestAll(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("Id").All(&users) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) + + var users2 []User + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").All(&users2) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + var users3 []*User + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + throwFailNow(t, AssertIs(users3 == nil, false)) +} + +func TestOne(t *testing.T) { + var user User + qs := dORM.QueryTable("user") + err := qs.One(&user) + throwFail(t, err) + + user = User{} + err = qs.OrderBy("Id").Limit(1).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "slene")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + user = User{} + err = qs.OrderBy("-Id").Limit(100).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "nobody")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + err = qs.Filter("user_name", "nothing").One(&user) + throwFail(t, AssertIs(err, ErrNoRows)) + +} + +func TestValues(t *testing.T) { + var maps []Params + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) + } + + num, err = qs.Filter("UserName", "slene").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestValuesList(t *testing.T) { + var list []ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").ValuesList(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) + } + + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) + } +} + +func TestValuesFlat(t *testing.T) { + var list ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) + } +} + +func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } + qs := dORM.QueryTable("user") + num, err := qs.Filter("profile__age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var user User + err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "slene").RelatedSel().One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) + + qs = dORM.QueryTable("user_profile") + num, err = qs.Filter("user__username", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var posts []*Post + qs = dORM.QueryTable("post") + num, err = qs.RelatedSel().All(&posts) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 4)) + + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) +} + +func TestReverseQuery(t *testing.T) { + var profile Profile + err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + profile = Profile{} + err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + var user User + err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + var posts []*Post + num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). + Filter("User__UserName", "slene").RelatedSel().All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].User == nil, false)) + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + + var tags []*Tag + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) + throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) +} + +func TestLoadRelated(t *testing.T) { + // load reverse foreign key + user := User{ID: 3} + + err := dORM.Read(&user) + throwFailNow(t, err) + + num, err := dORM.LoadRelated(&user, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) + + num, err = dORM.LoadRelated(&user, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + + num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + // load reverse one to one + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} + num, err = dORM.Update(&profile, "BestPost") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + err = dORM.Read(&profile) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&profile, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&profile, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) + + // load rel one to one + err = dORM.Read(&user) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&user, "Profile") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + num, err = dORM.LoadRelated(&user, "Profile", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) + throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) + + post := Post{ID: 2} + + // load rel foreign key + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&post, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(post.User.Profile == nil, false)) + throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) + + // load rel m2m + post = Post{ID: 2} + + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "Tags") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + + num, err = dORM.LoadRelated(&post, "Tags", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) + + // load reverse m2m + tag := Tag{ID: 1} + + err = dORM.Read(&tag) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&tag, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) + + num, err = dORM.LoadRelated(&tag, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) +} + +func TestQueryM2M(t *testing.T) { + post := Post{ID: 4} + m2m := dORM.QueryM2M(&post, "Tags") + + tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} + tag2 := &Tag{Name: "TestTag3"} + tag3 := []interface{}{&Tag{Name: "TestTag4"}} + + tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} + + for _, tag := range tags { + _, err := dORM.Insert(tag) + throwFailNow(t, err) + } + + num, err := m2m.Add(tag1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 5)) + + num, err = m2m.Remove(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + exist := m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + tag := Tag{Name: "test"} + _, err = dORM.Insert(&tag) + throwFailNow(t, err) + + m2m = dORM.QueryM2M(&tag, "Posts") + + post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} + post2 := &Post{Title: "TestPost3"} + post3 := []interface{}{&Post{Title: "TestPost4"}} + + posts := []interface{}{post1[0], post1[1], post2, post3[0]} + + for _, post := range posts { + p := post.(*Post) + p.User = &User{ID: 1} + _, err := dORM.Insert(post) + throwFailNow(t, err) + } + + num, err = m2m.Add(post1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + num, err = m2m.Remove(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + num, err = dORM.Delete(&tag) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) +} + +func TestQueryRelate(t *testing.T) { + // post := &Post{Id: 2} + + // qs := dORM.QueryRelate(post, "Tags") + // num, err := qs.Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + + // var tags []*Tag + // num, err = qs.All(&tags) + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + // throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) +} + +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + +func TestPrepareInsert(t *testing.T) { + qs := dORM.QueryTable("user") + i, err := qs.PrepareInsert() + throwFailNow(t, err) + + var user User + user.UserName = "testing1" + num, err := i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + user.UserName = "testing2" + num, err = i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + err = i.Close() + throwFail(t, err) + err = i.Close() + throwFail(t, AssertIs(err, ErrStmtClosed)) +} + +func TestRawExec(t *testing.T) { + Q := dDbBaser.TableQuote() + + query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) + res, err := dORM.Raw(query, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) + + res, err = dORM.Raw(query, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) +} + +func TestRawQueryRow(t *testing.T) { + var ( + Boolean bool + Char string + Text string + Time time.Time + Date time.Time + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 + ) + + dataValues := make(map[string]interface{}, len(DataValues)) + + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v + } + + Q := dDbBaser.TableQuote() + + cols := []string{ + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", + } + sep := fmt.Sprintf("%s, %s", Q, Q) + query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) + var id int + values := []interface{}{ + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, + } + err := dORM.Raw(query, 1).QueryRow(values...) + throwFailNow(t, err) + for i, col := range cols { + vu := values[i] + v := reflect.ValueOf(vu).Elem().Interface() + switch col { + case "id": + throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) + case "date": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) + case "datetime": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) + default: + throwFail(t, AssertIs(v, dataValues[col])) + } + } + + var ( + uid int + status *int + pid *int + ) + + cols = []string{ + "id", "Status", "profile_id", + } + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) + throwFail(t, err) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + +func TestQueryRows(t *testing.T) { + Q := dDbBaser.TableQuote() + + var datas []*Data + + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + + ind := reflect.Indirect(reflect.ValueOf(datas[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var ids []int + var usernames []string + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) + + //test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +func TestRawValues(t *testing.T) { + Q := dDbBaser.TableQuote() + + var maps []Params + query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) + num, err := dORM.Raw(query, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(query, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], "slene")) + } + + query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) + var list ParamsList + num, err = dORM.Raw(query).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) + } +} + +func TestRawPrepare(t *testing.T) { + switch { + case IsMysql || IsSqlite: + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid > 0, true)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + + case IsPostgres: + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + } + } +} + +func TestUpdate(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ + "is_staff": true, + "is_active": true, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColAdd, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMinus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMultiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColExcept, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) +} + +func TestDelete(t *testing.T) { + qs := dORM.QueryTable("user_profile") + num, err := qs.Filter("user__user_name", "slene").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestTransaction(t *testing.T) { + // this test worked when database support transaction + + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name__in", names).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + err = o.Begin() + throwFail(t, err) + + tag.Name = "commit" + id, err = o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + +} + +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestReadOrCreate(t *testing.T) { + u := &User{ + UserName: "Kyle", + Email: "kylemcc@gmail.com", + Password: "other_pass", + Status: 7, + IsStaff: false, + IsActive: true, + } + + created, pk, err := dORM.ReadOrCreate(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) + throwFail(t, AssertIs(u.UserName, "Kyle")) + throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) + throwFail(t, AssertIs(u.Password, "other_pass")) + throwFail(t, AssertIs(u.Status, 7)) + throwFail(t, AssertIs(u.IsStaff, false)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) + + nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} + created, pk, err = dORM.ReadOrCreate(nu, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.UserName, u.UserName)) + throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above + throwFail(t, AssertIs(nu.Password, u.Password)) + throwFail(t, AssertIs(nu.Status, u.Status)) + throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) + throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + + dORM.Delete(u) +} + +func TestInLine(t *testing.T) { + name := "inline" + email := "hello@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + il := NewInLine() + il.ID = 1 + err = dORM.Read(il) + throwFail(t, err) + + throwFail(t, AssertIs(il.Name, name)) + throwFail(t, AssertIs(il.Email, email)) + throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) +} + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, _, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err := dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "i_d", + "NO": "n_o", + "NOO": "n_o_o", + "NOOooOOoo": "n_o_ooo_o_ooo", + "OrderNO": "order_n_o", + "tagName": "tag_name", + "tag_Name": "tag__name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} + +func TestInsertOrUpdate(t *testing.T) { + RegisterModel(new(User)) + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + test := User{UserName: "unique_username133"} + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + //test1 + _, err := dORM.InsertOrUpdate(&user1, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + } + //test2 + _, err = dORM.InsertOrUpdate(&user2, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } + //test3 + + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + //test4 - + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + //test5 * + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + //test6 / + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/pkg/orm/qb.go b/pkg/orm/qb.go new file mode 100644 index 00000000..e0655a17 --- /dev/null +++ b/pkg/orm/qb.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "errors" + +// QueryBuilder is the Query builder interface +type QueryBuilder interface { + Select(fields ...string) QueryBuilder + ForUpdate() QueryBuilder + From(tables ...string) QueryBuilder + InnerJoin(table string) QueryBuilder + LeftJoin(table string) QueryBuilder + RightJoin(table string) QueryBuilder + On(cond string) QueryBuilder + Where(cond string) QueryBuilder + And(cond string) QueryBuilder + Or(cond string) QueryBuilder + In(vals ...string) QueryBuilder + OrderBy(fields ...string) QueryBuilder + Asc() QueryBuilder + Desc() QueryBuilder + Limit(limit int) QueryBuilder + Offset(offset int) QueryBuilder + GroupBy(fields ...string) QueryBuilder + Having(cond string) QueryBuilder + Update(tables ...string) QueryBuilder + Set(kv ...string) QueryBuilder + Delete(tables ...string) QueryBuilder + InsertInto(table string, fields ...string) QueryBuilder + Values(vals ...string) QueryBuilder + Subquery(sub string, alias string) string + String() string +} + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + if driver == "mysql" { + qb = new(MySQLQueryBuilder) + } else if driver == "tidb" { + qb = new(TiDBQueryBuilder) + } else if driver == "postgres" { + err = errors.New("postgres query builder is not supported yet") + } else if driver == "sqlite" { + err = errors.New("sqlite query builder is not supported yet") + } else { + err = errors.New("unknown driver for query builder") + } + return +} diff --git a/pkg/orm/qb_mysql.go b/pkg/orm/qb_mysql.go new file mode 100644 index 00000000..23bdc9ee --- /dev/null +++ b/pkg/orm/qb_mysql.go @@ -0,0 +1,185 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// CommaSpace is the separation +const CommaSpace = ", " + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/qb_tidb.go b/pkg/orm/qb_tidb.go new file mode 100644 index 00000000..87b3ae84 --- /dev/null +++ b/pkg/orm/qb_tidb.go @@ -0,0 +1,182 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/pkg/orm/types.go b/pkg/orm/types.go new file mode 100644 index 00000000..2fd10774 --- /dev/null +++ b/pkg/orm/types.go @@ -0,0 +1,473 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" +) + +// Driver define database driver +type Driver interface { + Name() string + Type() DriverType +} + +// Fielder define field info +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} +} + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + //args[0] bool true useDefaultRelsDepth ; false depth 0 + //args[0] int loadRelationDepth + //args[1] int limit default limit 1000 + //args[2] int offset default offset 0 + //args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter interface { + Insert(interface{}) (int64, error) + Close() error +} + +// QuerySeter query seter +type QuerySeter interface { + // add condition expression to QuerySeter. + // for example: + // filter by UserName == 'slene' + // qs.Filter("UserName", "slene") + // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 + // Filter("profile__Age", 28) + // // time compare + // qs.Filter("created", time.Now()) + Filter(string, ...interface{}) QuerySeter + // add raw sql to querySeter. + // for example: + // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") + // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) + FilterRaw(string, string) QuerySeter + // add NOT condition to querySeter. + // have the same usage as Filter + Exclude(string, ...interface{}) QuerySeter + // set condition to QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond1).Count() + SetCond(*Condition) QuerySeter + // get condition from QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) + // qs = qs.SetCond(cond) + // cond = qs.GetCond() + // cond := cond.Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond).Count() + GetCond() *Condition + // add LIMIT value. + // args[0] means offset, e.g. LIMIT num,offset. + // if Limit <= 0 then Limit will be set to default limit ,eg 1000 + // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 + // for example: + // qs.Limit(10, 2) + // // sql-> limit 10 offset 2 + Limit(limit interface{}, args ...interface{}) QuerySeter + // add OFFSET value + // same as Limit function's args[0] + Offset(offset interface{}) QuerySeter + // add GROUP BY expression + // for example: + // qs.GroupBy("id") + GroupBy(exprs ...string) QuerySeter + // add ORDER expression. + // "column" means ASC, "-column" means DESC. + // for example: + // qs.OrderBy("-status") + OrderBy(exprs ...string) QuerySeter + // set relation model to query together. + // it will query relation models and assign to parent model. + // for example: + // // will load all related fields use left join . + // qs.RelatedSel().One(&user) + // // will load related field only profile + // qs.RelatedSel("profile").One(&user) + // user.Profile.Age = 32 + RelatedSel(params ...interface{}) QuerySeter + // Set Distinct + // for example: + // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). + // Distinct(). + // All(&permissions) + Distinct() QuerySeter + // set FOR UPDATE to query. + // for example: + // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) + ForUpdate() QuerySeter + // return QuerySeter execution result number + // for example: + // num, err = qs.Filter("profile__age__gt", 28).Count() + Count() (int64, error) + // check result empty or not after QuerySeter executed + // the same as QuerySeter.Count > 0 + Exist() bool + // execute update with parameters + // for example: + // num, err = qs.Filter("user_name", "slene").Update(Params{ + // "Nums": ColValue(Col_Minus, 50), + // }) // user slene's Nums will minus 50 + // num, err = qs.Filter("UserName", "slene").Update(Params{ + // "user_name": "slene2" + // }) // user slene's name will change to slene2 + Update(values Params) (int64, error) + // delete from table + //for example: + // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + // //delete two user who's name is testing1 or testing2 + Delete() (int64, error) + // return a insert queryer. + // it can be used in times. + // example: + // i,err := sq.PrepareInsert() + // num, err = i.Insert(&user1) // user table will add one record user1 at once + // num, err = i.Insert(&user2) // user table will add one record user2 at once + // err = i.Close() //don't forget call Close + PrepareInsert() (Inserter, error) + // query all data and map to containers. + // cols means the columns when querying. + // for example: + // var users []*User + // qs.All(&users) // users[0],users[1],users[2] ... + All(container interface{}, cols ...string) (int64, error) + // query one row data and map to containers. + // cols means the columns when querying. + // for example: + // var user User + // qs.One(&user) //user.UserName == "slene" + One(container interface{}, cols ...string) error + // query all data and map to []map[string]interface. + // expres means condition expression. + // it converts data to []map[column]value. + // for example: + // var maps []Params + // qs.Values(&maps) //maps[0]["UserName"]=="slene" + Values(results *[]Params, exprs ...string) (int64, error) + // query all data and map to [][]interface + // it converts data to [][column_index]value + // for example: + // var list []ParamsList + // qs.ValuesList(&list) // list[0][1] == "slene" + ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + // query all data and map to []interface. + // it's designed for one column record set, auto change to []value, not [][column]value. + // for example: + // var list ParamsList + // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" + ValuesFlat(result *ParamsList, expr string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) +} + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer interface { + // add models to origin models when creating queryM2M. + // example: + // m2m := orm.QueryM2M(post,"Tag") + // m2m.Add(&Tag1{},&Tag2{}) + // for _,tag := range post.Tags{}{ ... } + // param could also be any of the follow + // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} + // &Tag{Id:5,Name: "TestTag3"} + // []interface{}{&Tag{Id:6,Name: "TestTag4"}} + // insert one or more rows to m2m table + // make sure the relation is defined in post model struct tag. + Add(...interface{}) (int64, error) + // remove models following the origin model relationship + // only delete rows from m2m table + // for example: + //tag3 := &Tag{Id:5,Name: "TestTag3"} + //num, err = m2m.Remove(tag3) + Remove(...interface{}) (int64, error) + // check model is existed in relationship of origin model + Exist(interface{}) bool + // clean all models in related of origin model + Clear() (int64, error) + // count all related models of origin model + Count() (int64, error) +} + +// RawPreparer raw query statement +type RawPreparer interface { + Exec(...interface{}) (sql.Result, error) + Close() error +} + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter interface { + //execute sql and get result + Exec() (sql.Result, error) + //query data and map to container + //for example: + // var name string + // var id int + // rs.QueryRow(&id,&name) // id==2 name=="slene" + QueryRow(containers ...interface{}) error + + // query data rows and map to container + // var ids []int + // var names []int + // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) + // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} + QueryRows(containers ...interface{}) (int64, error) + SetArgs(...interface{}) RawSeter + // query data to []map[string]interface + // see QuerySeter's Values + Values(container *[]Params, cols ...string) (int64, error) + // query data to [][]interface + // see QuerySeter's ValuesList + ValuesList(container *[]ParamsList, cols ...string) (int64, error) + // query data to []interface + // see QuerySeter's ValuesFlat + ValuesFlat(container *ParamsList, cols ...string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + + // return prepared raw statement for used in times. + // for example: + // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) + Prepare() (RawPreparer, error) +} + +// stmtQuerier statement querier +type stmtQuerier interface { + Close() error + Exec(args ...interface{}) (sql.Result, error) + //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + Query(args ...interface{}) (*sql.Rows, error) + //QueryContext(args ...interface{}) (*sql.Rows, error) + QueryRow(args ...interface{}) *sql.Row + //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row +} + +// db querier +type dbQuerier interface { + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// type DB interface { +// Begin() (*sql.Tx, error) +// Prepare(query string) (stmtQuerier, error) +// Exec(query string, args ...interface{}) (sql.Result, error) +// Query(query string, args ...interface{}) (*sql.Rows, error) +// QueryRow(query string, args ...interface{}) *sql.Row +// } + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { + Commit() error + Rollback() error +} + +// base database struct +type dbBaser interface { + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + SupportUpdateJoin() bool + UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + OperatorSQL(string) string + GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*fieldInfo, string, *string) + PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) + MaxLimit() uint64 + TableQuote() string + ReplaceMarks(*string) + HasReturningID(*modelInfo, *string) bool + TimeFromDB(*time.Time, *time.Location) + TimeToDB(*time.Time, *time.Location) + DbTypes() map[string]string + GetTables(dbQuerier) (map[string]bool, error) + GetColumns(dbQuerier, string) (map[string][3]string, error) + ShowTablesQuery() string + ShowColumnsQuery(string) string + IndexExists(dbQuerier, string, string) bool + collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error +} diff --git a/pkg/orm/utils.go b/pkg/orm/utils.go new file mode 100644 index 00000000..3ff76772 --- /dev/null +++ b/pkg/orm/utils.go @@ -0,0 +1,319 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "strings" + "time" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo string + +// Set string +func (f *StrTo) Set(v string) { + if v != "" { + *f = StrTo(v) + } else { + f.Clear() + } +} + +// Clear string +func (f *StrTo) Clear() { + *f = StrTo(0x1E) +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return string(f) != string(0x1E) +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +// Int string to int +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err +} + +// String string to string +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/orm/utils_test.go b/pkg/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +}