diff --git a/orm/cmd.go b/orm/cmd.go index 97545da4..95be7f4a 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -16,6 +16,7 @@ var ( commands = make(map[string]commander) ) +// print help. func printHelp(errs ...string) { content := `orm command usage: @@ -31,6 +32,7 @@ func printHelp(errs ...string) { os.Exit(2) } +// listen for orm command and then run it if command arguments passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -58,6 +60,7 @@ func RunCommand() { } } +// sync database struct command interface. type commandSyncDb struct { al *alias force bool @@ -66,6 +69,7 @@ type commandSyncDb struct { rtOnError bool } +// parse orm command line arguments. func (d *commandSyncDb) Parse(args []string) { var name string @@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSyncDb) Run() error { var drops []string if d.force { @@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error { return nil } +// database creation commander interface implement. type commandSqlAll struct { al *alias } +// parse orm command line arguments. func (d *commandSqlAll) Parse(args []string) { var name string @@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSqlAll) Run() error { sqls, indexes := getDbCreateSql(d.al) var all []string @@ -243,6 +251,10 @@ func init() { commands["sqlall"] = new(commandSqlAll) } +// run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. func RunSyncdb(name string, force bool, verbose bool) error { BootStrap() diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 6fcb4b01..8f6d94db 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -12,6 +12,7 @@ type dbIndex struct { Sql string } +// create database drop sql. func getDbDropSql(al *alias) (sqls []string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") @@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) { return sqls } +// get database column type string. func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() fieldType := fi.fieldType @@ -79,6 +81,7 @@ checkColumn: return } +// create alter sql string. func getColumnAddQuery(al *alias, fi *fieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) @@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) } +// create database creation string. func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") diff --git a/orm/db.go b/orm/db.go index c6e92ec9..10967fc5 100644 --- a/orm/db.go +++ b/orm/db.go @@ -15,7 +15,7 @@ const ( ) var ( - ErrMissPK = errors.New("missed pk value") + ErrMissPK = errors.New("missed pk value") // missing pk error ) var ( @@ -45,12 +45,15 @@ var ( } ) +// an instance of dbBaser interface/ type dbBase struct { ins dbBaser } +// check dbBase implements dbBaser interface. var _ dbBaser = new(dbBase) +// get struct columns values as interface slice. func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { var columns []string @@ -87,6 +90,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, return } +// get one field value in struct column as interface. func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { var value interface{} if fi.pk { @@ -155,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val return value, nil } +// create insert sql preparation statement object. func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() @@ -180,6 +185,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, return stmt, query, err } +// insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { @@ -200,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } } +// query sql ,read records and persist in dbBaser. func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { var whereCols []string var args []interface{} @@ -259,6 +266,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo return nil } +// execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)-1) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) @@ -269,6 +277,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return d.InsertValue(q, mi, false, names, values) } +// multi-insert sql with given slice struct reflect.Value. func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 @@ -325,6 +334,8 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul return cnt, nil } +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() @@ -364,6 +375,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } } +// execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -404,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -445,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) @@ -520,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo @@ -546,6 +564,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * return nil } +// delete table-related records. func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -623,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// read related records. func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) @@ -832,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return cnt, nil } +// excute count sql and return count result int64. func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -852,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } +// generate sql with replacing operator string placeholders and replaced values. func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { sql := "" params := getFlatParams(fi, args, tz) @@ -909,6 +931,7 @@ func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { // default not use } +// set values to struct column. func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() @@ -930,6 +953,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } } +// convert value from database result to value following in field type. func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil @@ -1082,6 +1106,7 @@ end: } +// set one value to struct column field. func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { fieldType := fi.fieldType @@ -1156,6 +1181,7 @@ setValue: return value, nil } +// query sql, read values , save to *[]ParamList. func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( @@ -1323,6 +1349,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } +// flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true } @@ -1331,30 +1358,37 @@ func (d *dbBase) MaxLimit() uint64 { return 18446744073709551615 } +// return quote. func (d *dbBase) TableQuote() string { return "`" } +// replace value placeholer in parametered sql string. func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } +// flag of RETURNING sql. func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// convert time to db. func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// get database types. func (d *dbBase) DbTypes() map[string]string { return nil } +// gt all tables. func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { tables := make(map[string]bool) query := d.ins.ShowTablesQuery() @@ -1379,6 +1413,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { return tables, nil } +// get all cloumns in table. func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) @@ -1405,18 +1440,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e return columns, nil } +// not implement. func (d *dbBase) OperatorSql(operator string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowTablesQuery() string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowColumnsQuery(table string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) IndexExists(dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/orm/db_alias.go b/orm/db_alias.go index 24924312..d50b6ebd 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -9,27 +9,32 @@ import ( "time" ) +// database driver constant int. type DriverType int const ( - _ DriverType = iota - DR_MySQL - DR_Sqlite - DR_Oracle - DR_Postgres + _ DriverType = iota // int enum type + DR_MySQL // mysql + DR_Sqlite // sqlite + DR_Oracle // oracle + DR_Postgres // pgsql ) +// database driver string. type driver string +// get type constant int of current driver.. func (d driver) Type() DriverType { a, _ := dataBaseCache.get(string(d)) return a.Driver } +// get name of current driver func (d driver) Name() string { return string(d) } +// check driver iis implemented Driver interface or not. var _ Driver = new(driver) var ( @@ -47,11 +52,13 @@ var ( } ) +// database alias cacher. type _dbCache struct { mux sync.RWMutex cache map[string]*alias } +// add database alias with original name. func (ac *_dbCache) add(name string, al *alias) (added bool) { ac.mux.Lock() defer ac.mux.Unlock() @@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) { return } +// get database alias if cached. func (ac *_dbCache) get(name string) (al *alias, ok bool) { ac.mux.RLock() defer ac.mux.RUnlock() @@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) { return } +// get default alias. func (ac *_dbCache) getDefault() (al *alias) { al, _ = ac.get("default") return