diff --git a/beego.go b/beego.go index c22719c6..640fb5ab 100644 --- a/beego.go +++ b/beego.go @@ -221,7 +221,7 @@ func Run() { middleware.VERSION = VERSION middleware.AppName = AppName - middleware.RegisterErrorHander() + middleware.RegisterErrorHandler() if EnableAdmin { go BeeAdminApp.Run() diff --git a/middleware/error.go b/middleware/error.go index 35d9eb59..5c12b533 100644 --- a/middleware/error.go +++ b/middleware/error.go @@ -61,6 +61,7 @@ var tpl = ` ` +// render default application error page with error and stack string. func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) { t, _ := template.New("beegoerrortemp").Parse(tpl) data := make(map[string]string) @@ -175,13 +176,14 @@ var errtpl = ` ` +// map of http handlers for each error string. var ErrorMaps map[string]http.HandlerFunc func init() { ErrorMaps = make(map[string]http.HandlerFunc) } -//404 +// show 404 notfound error. func NotFound(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -199,7 +201,7 @@ func NotFound(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//401 +// show 401 unauthorized error. func Unauthorized(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -215,7 +217,7 @@ func Unauthorized(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//403 +// show 403 forbidden error. func Forbidden(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -232,7 +234,7 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//503 +// show 503 service unavailable error. func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -248,7 +250,7 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//500 +// show 500 internal server error. func InternalServerError(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -262,15 +264,18 @@ func InternalServerError(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } +// show 500 internal error with simple text string. func SimpleServerError(rw http.ResponseWriter, r *http.Request) { http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } +// add http handler for given error string. func Errorhandler(err string, h http.HandlerFunc) { ErrorMaps[err] = h } -func RegisterErrorHander() { +// register default error http handlers, 404,401,403,500 and 503. +func RegisterErrorHandler() { if _, ok := ErrorMaps["404"]; !ok { ErrorMaps["404"] = NotFound } @@ -292,6 +297,8 @@ func RegisterErrorHander() { } } +// show error string as simple text message. +// if error string is empty, show 500 error as default. func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) { if h, ok := ErrorMaps[errcode]; ok { isint, err := strconv.Atoi(errcode) diff --git a/middleware/exceptions.go b/middleware/exceptions.go index 5bf85956..b221dfcb 100644 --- a/middleware/exceptions.go +++ b/middleware/exceptions.go @@ -2,16 +2,19 @@ package middleware import "fmt" +// http exceptions type HTTPException struct { StatusCode int // http status code 4xx, 5xx Description string } +// return http exception error string, e.g. "400 Bad Request". func (e *HTTPException) Error() string { - // return `status description`, e.g. `400 Bad Request` return fmt.Sprintf("%d %s", e.StatusCode, e.Description) } +// map of http exceptions for each http status code int. +// defined 400,401,403,404,405,500,502,503 and 504 default. var HTTPExceptionMaps map[int]HTTPException func init() { diff --git a/orm/cmd.go b/orm/cmd.go index 97545da4..95be7f4a 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -16,6 +16,7 @@ var ( commands = make(map[string]commander) ) +// print help. func printHelp(errs ...string) { content := `orm command usage: @@ -31,6 +32,7 @@ func printHelp(errs ...string) { os.Exit(2) } +// listen for orm command and then run it if command arguments passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -58,6 +60,7 @@ func RunCommand() { } } +// sync database struct command interface. type commandSyncDb struct { al *alias force bool @@ -66,6 +69,7 @@ type commandSyncDb struct { rtOnError bool } +// parse orm command line arguments. func (d *commandSyncDb) Parse(args []string) { var name string @@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSyncDb) Run() error { var drops []string if d.force { @@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error { return nil } +// database creation commander interface implement. type commandSqlAll struct { al *alias } +// parse orm command line arguments. func (d *commandSqlAll) Parse(args []string) { var name string @@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSqlAll) Run() error { sqls, indexes := getDbCreateSql(d.al) var all []string @@ -243,6 +251,10 @@ func init() { commands["sqlall"] = new(commandSqlAll) } +// run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. func RunSyncdb(name string, force bool, verbose bool) error { BootStrap() diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 6fcb4b01..8f6d94db 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -12,6 +12,7 @@ type dbIndex struct { Sql string } +// create database drop sql. func getDbDropSql(al *alias) (sqls []string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") @@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) { return sqls } +// get database column type string. func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() fieldType := fi.fieldType @@ -79,6 +81,7 @@ checkColumn: return } +// create alter sql string. func getColumnAddQuery(al *alias, fi *fieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) @@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) } +// create database creation string. func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") diff --git a/orm/db.go b/orm/db.go index c6e92ec9..f12e76fb 100644 --- a/orm/db.go +++ b/orm/db.go @@ -15,7 +15,7 @@ const ( ) var ( - ErrMissPK = errors.New("missed pk value") + ErrMissPK = errors.New("missed pk value") // missing pk error ) var ( @@ -45,12 +45,15 @@ var ( } ) +// an instance of dbBaser interface/ type dbBase struct { ins dbBaser } +// check dbBase implements dbBaser interface. var _ dbBaser = new(dbBase) +// get struct columns values as interface slice. func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { var columns []string @@ -87,6 +90,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, return } +// get one field value in struct column as interface. func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { var value interface{} if fi.pk { @@ -155,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val return value, nil } +// create insert sql preparation statement object. func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() @@ -180,6 +185,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, return stmt, query, err } +// insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { @@ -200,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } } +// query sql ,read records and persist in dbBaser. func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { var whereCols []string var args []interface{} @@ -259,6 +266,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo return nil } +// execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)-1) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) @@ -269,6 +277,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return d.InsertValue(q, mi, false, names, values) } +// multi-insert sql with given slice struct reflect.Value. func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 @@ -325,6 +334,8 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul return cnt, nil } +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() @@ -364,6 +375,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } } +// execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -404,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -445,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) @@ -520,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo @@ -546,6 +564,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * return nil } +// delete table-related records. func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -623,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// read related records. func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) @@ -832,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return cnt, nil } +// excute count sql and return count result int64. func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -852,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } +// generate sql with replacing operator string placeholders and replaced values. func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { sql := "" params := getFlatParams(fi, args, tz) @@ -905,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri return sql, params } +// gernerate sql string with inner function, such as UPPER(text). func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { // default not use } +// set values to struct column. func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() @@ -930,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } } +// convert value from database result to value following in field type. func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil @@ -1082,6 +1107,7 @@ end: } +// set one value to struct column field. func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { fieldType := fi.fieldType @@ -1156,6 +1182,7 @@ setValue: return value, nil } +// query sql, read values , save to *[]ParamList. func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( @@ -1323,6 +1350,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } +// flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true } @@ -1331,30 +1359,37 @@ func (d *dbBase) MaxLimit() uint64 { return 18446744073709551615 } +// return quote. func (d *dbBase) TableQuote() string { return "`" } +// replace value placeholer in parametered sql string. func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } +// flag of RETURNING sql. func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// convert time to db. func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// get database types. func (d *dbBase) DbTypes() map[string]string { return nil } +// gt all tables. func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { tables := make(map[string]bool) query := d.ins.ShowTablesQuery() @@ -1379,6 +1414,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { return tables, nil } +// get all cloumns in table. func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) @@ -1405,18 +1441,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e return columns, nil } +// not implement. func (d *dbBase) OperatorSql(operator string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowTablesQuery() string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowColumnsQuery(table string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) IndexExists(dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/orm/db_alias.go b/orm/db_alias.go index 24924312..d50b6ebd 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -9,27 +9,32 @@ import ( "time" ) +// database driver constant int. type DriverType int const ( - _ DriverType = iota - DR_MySQL - DR_Sqlite - DR_Oracle - DR_Postgres + _ DriverType = iota // int enum type + DR_MySQL // mysql + DR_Sqlite // sqlite + DR_Oracle // oracle + DR_Postgres // pgsql ) +// database driver string. type driver string +// get type constant int of current driver.. func (d driver) Type() DriverType { a, _ := dataBaseCache.get(string(d)) return a.Driver } +// get name of current driver func (d driver) Name() string { return string(d) } +// check driver iis implemented Driver interface or not. var _ Driver = new(driver) var ( @@ -47,11 +52,13 @@ var ( } ) +// database alias cacher. type _dbCache struct { mux sync.RWMutex cache map[string]*alias } +// add database alias with original name. func (ac *_dbCache) add(name string, al *alias) (added bool) { ac.mux.Lock() defer ac.mux.Unlock() @@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) { return } +// get database alias if cached. func (ac *_dbCache) get(name string) (al *alias, ok bool) { ac.mux.RLock() defer ac.mux.RUnlock() @@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) { return } +// get default alias. func (ac *_dbCache) getDefault() (al *alias) { al, _ = ac.get("default") return diff --git a/orm/db_mysql.go b/orm/db_mysql.go index da123079..566f2992 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -4,6 +4,7 @@ import ( "fmt" ) +// mysql operators. var mysqlOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ?", @@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{ "iendswith": "LIKE ?", } +// mysql column field types. var mysqlTypes = map[string]string{ "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY", @@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{ "float64-decimal": "numeric(%d, %d)", } +// mysql dbBaser implementation. type dbBaseMysql struct { dbBase } var _ dbBaser = new(dbBaseMysql) +// get mysql operator. func (d *dbBaseMysql) OperatorSql(operator string) string { return mysqlOperators[operator] } +// get mysql table field types. func (d *dbBaseMysql) DbTypes() map[string]string { return mysqlTypes } +// show table sql for mysql. func (d *dbBaseMysql) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } +// show columns sql of table for mysql. func (d *dbBaseMysql) ShowColumnsQuery(table string) string { return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } +// execute sql to check index exist. func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) @@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +// create new mysql dbBaser. func newdbBaseMysql() dbBaser { b := new(dbBaseMysql) b.ins = b diff --git a/orm/db_oracle.go b/orm/db_oracle.go index ca1715ef..8e374122 100644 --- a/orm/db_oracle.go +++ b/orm/db_oracle.go @@ -1,11 +1,13 @@ package orm +// oracle dbBaser type dbBaseOracle struct { dbBase } var _ dbBaser = new(dbBaseOracle) +// create oracle dbBaser. func newdbBaseOracle() dbBaser { b := new(dbBaseOracle) b.ins = b diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 4058fc10..d26511c0 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -5,6 +5,7 @@ import ( "strconv" ) +// postgresql operators. var postgresOperators = map[string]string{ "exact": "= ?", "iexact": "= UPPER(?)", @@ -20,6 +21,7 @@ var postgresOperators = map[string]string{ "iendswith": "LIKE UPPER(?)", } +// postgresql column field types. var postgresTypes = map[string]string{ "auto": "serial NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY", @@ -40,16 +42,19 @@ var postgresTypes = map[string]string{ "float64-decimal": "numeric(%d, %d)", } +// postgresql dbBaser. type dbBasePostgres struct { dbBase } var _ dbBaser = new(dbBasePostgres) +// get postgresql operator. func (d *dbBasePostgres) OperatorSql(operator string) string { return postgresOperators[operator] } +// generate functioned sql string, such as contains(text). func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { switch operator { case "contains", "startswith", "endswith": @@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, } } +// postgresql unsupports updating joined record. func (d *dbBasePostgres) SupportUpdateJoin() bool { return false } @@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } +// postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` } +// postgresql value placeholder is $n. +// replace default ? to $n. func (d *dbBasePostgres) ReplaceMarks(query *string) { q := *query num := 0 @@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { *query = string(data) } +// make returning sql support for postgresql. func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { if mi.fields.pk.auto { if query != nil { @@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) return } +// show table sql for postgresql. func (d *dbBasePostgres) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" } +// show table columns sql for postgresql. func (d *dbBasePostgres) ShowColumnsQuery(table string) string { return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) } +// get column types of postgresql. func (d *dbBasePostgres) DbTypes() map[string]string { return postgresTypes } +// check index exist in postgresql. func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) row := db.QueryRow(query) @@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo return cnt > 0 } +// create new postgresql dbBaser. func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) b.ins = b diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index 7711ded0..81692e2c 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// sqlite operators. var sqliteOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ? ESCAPE '\\'", @@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{ "iendswith": "LIKE ? ESCAPE '\\'", } +// sqlite column types. var sqliteTypes = map[string]string{ "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", "pk": "NOT NULL PRIMARY KEY", @@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{ "float64-decimal": "decimal", } +// sqlite dbBaser. type dbBaseSqlite struct { dbBase } var _ dbBaser = new(dbBaseSqlite) +// get sqlite operator. func (d *dbBaseSqlite) OperatorSql(operator string) string { return sqliteOperators[operator] } +// generate functioned sql for sqlite. +// only support DATE(text). func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { if fi.fieldType == TypeDateField { *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) } } +// unable updating joined record in sqlite. func (d *dbBaseSqlite) SupportUpdateJoin() bool { return false } +// max int in sqlite. func (d *dbBaseSqlite) MaxLimit() uint64 { return 9223372036854775807 } +// get column types in sqlite. func (d *dbBaseSqlite) DbTypes() map[string]string { return sqliteTypes } +// get show tables sql in sqlite. func (d *dbBaseSqlite) ShowTablesQuery() string { return "SELECT name FROM sqlite_master WHERE type = 'table'" } +// get columns in sqlite. func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) rows, err := db.Query(query) @@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str return columns, nil } +// get show columns sql in sqlite. func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { return fmt.Sprintf("pragma table_info('%s')", table) } +// check index exist in sqlite. func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) rows, err := db.Query(query) @@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool return false } +// create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) b.ins = b diff --git a/orm/db_tables.go b/orm/db_tables.go index f5cacf38..854c4214 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -6,6 +6,7 @@ import ( "time" ) +// table info struct. type dbTable struct { id int index string @@ -18,6 +19,7 @@ type dbTable struct { jtl *dbTable } +// tables collection struct, contains some tables. type dbTables struct { tablesM map[string]*dbTable tables []*dbTable @@ -26,6 +28,8 @@ type dbTables struct { skipEnd bool } +// set table info to collection. +// if not exist, create new. func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { name := strings.Join(names, ExprSep) if j, ok := t.tablesM[name]; ok { @@ -42,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) return t.tablesM[name] } +// add table info to collection. func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) if _, ok := t.tablesM[name]; ok == false { @@ -54,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) return t.tablesM[name], false } +// get table info in collection. func (t *dbTables) get(name string) (*dbTable, bool) { j, ok := t.tablesM[name] return j, ok } +// get related fields info in recursive depth loop. +// loop once, depth decreases one. func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { if depth < 0 || fi.fieldType == RelManyToMany { return related @@ -79,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related [] return related } +// parse related fields. func (t *dbTables) parseRelated(rels []string, depth int) { relsNum := len(rels) @@ -140,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { } } +// generate join string. func (t *dbTables) getJoinSql() (join string) { Q := t.base.TableQuote() @@ -186,6 +196,7 @@ func (t *dbTables) getJoinSql() (join string) { return } +// parse orm model struct field tag expression. func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { var ( jtl *dbTable @@ -300,6 +311,7 @@ loopFor: return } +// generate condition sql. func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { if cond == nil || cond.IsEmpty() { return @@ -364,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe return } +// generate order sql. func (t *dbTables) getOrderSql(orders []string) (orderSql string) { if len(orders) == 0 { return @@ -392,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) { return } +// generate limit sql. func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { if limit == 0 { limit = int64(DefaultRowsLimit) @@ -414,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits return } +// crete new tables collection. func newDbTables(mi *modelInfo, base dbBaser) *dbTables { tables := &dbTables{} tables.tablesM = make(map[string]*dbTable) diff --git a/orm/db_utils.go b/orm/db_utils.go index e2178294..34de8186 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -6,6 +6,7 @@ import ( "time" ) +// get table alias. func getDbAlias(name string) *alias { if al, ok := dataBaseCache.get(name); ok { return al @@ -15,6 +16,7 @@ func getDbAlias(name string) *alias { return nil } +// get pk column info. func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { fi := mi.fields.pk @@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac return } +// get fields description as flatted string. func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { outFor: diff --git a/orm/models.go b/orm/models.go index 1cb25c4c..5744d865 100644 --- a/orm/models.go +++ b/orm/models.go @@ -41,6 +41,7 @@ var ( } ) +// model info collection type _modelCache struct { sync.RWMutex orders []string @@ -49,6 +50,7 @@ type _modelCache struct { done bool } +// get all model info func (mc *_modelCache) all() map[string]*modelInfo { m := make(map[string]*modelInfo, len(mc.cache)) for k, v := range mc.cache { @@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo { return m } +// get orderd model info func (mc *_modelCache) allOrdered() []*modelInfo { m := make([]*modelInfo, 0, len(mc.orders)) for _, table := range mc.orders { @@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo { return m } +// get model info by table name func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { mi, ok = mc.cache[table] return } +// get model info by field name func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { mi, ok = mc.cacheByFN[name] return } +// set model info to collection func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { mii := mc.cache[table] mc.cache[table] = mi @@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { return mii } +// clean all model info. func (mc *_modelCache) clean() { mc.orders = make([]string, 0) mc.cache = make(map[string]*modelInfo) diff --git a/orm/models_boot.go b/orm/models_boot.go index 3274b187..03caeb62 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -8,6 +8,8 @@ import ( "strings" ) +// register models. +// prefix means table name prefix. func registerModel(model interface{}, prefix string) { val := reflect.ValueOf(model) ind := reflect.Indirect(val) @@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) { modelCache.set(table, info) } +// boostrap models func bootStrap() { if modelCache.done { return @@ -281,6 +284,7 @@ end: } } +// register models func RegisterModel(models ...interface{}) { if modelCache.done { panic(fmt.Errorf("RegisterModel must be run before BootStrap")) @@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) { } } +// bootrap models. +// make all model parsed and can not add more models func BootStrap() { if modelCache.done { return diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 03736091..fadbb335 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -9,6 +9,7 @@ import ( var errSkipField = errors.New("skip field") +// field info collection type fields struct { pk *fieldInfo columns map[string]*fieldInfo @@ -23,6 +24,7 @@ type fields struct { dbcols []string } +// add field info func (f *fields) Add(fi *fieldInfo) (added bool) { if f.fields[fi.name] == nil && f.columns[fi.column] == nil { f.columns[fi.column] = fi @@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) { return true } +// get field info by name func (f *fields) GetByName(name string) *fieldInfo { return f.fields[name] } +// get field info by column name func (f *fields) GetByColumn(column string) *fieldInfo { return f.columns[column] } +// get field info by string, name is prior func (f *fields) GetByAny(name string) (*fieldInfo, bool) { if fi, ok := f.fields[name]; ok { return fi, ok @@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) { return nil, false } +// create new field info collection func newFields() *fields { f := new(fields) f.fields = make(map[string]*fieldInfo) @@ -79,6 +85,7 @@ func newFields() *fields { return f } +// single field info type fieldInfo struct { mi *modelInfo fieldIndex int @@ -115,6 +122,7 @@ type fieldInfo struct { onDelete string } +// new field info func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { var ( tag string diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 7a173781..b596fc6a 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -7,6 +7,7 @@ import ( "reflect" ) +// single model info type modelInfo struct { pkg string name string @@ -20,6 +21,7 @@ type modelInfo struct { isThrough bool } +// new model info func newModelInfo(val reflect.Value) (info *modelInfo) { var ( err error @@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { return } +// combine related model info to new model info. +// prepare for relation models query. func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { info = new(modelInfo) info.fields = newFields() diff --git a/orm/models_utils.go b/orm/models_utils.go index 38095b7e..1466a724 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -7,10 +7,12 @@ import ( "time" ) +// get reflect.Type name with package path. func getFullName(typ reflect.Type) string { return typ.PkgPath() + "." + typ.Name() } +// get table name. method, or field name. auto snaked. func getTableName(val reflect.Value) string { ind := reflect.Indirect(val) fun := val.MethodByName("TableName") @@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string { return snakeString(ind.Type().Name()) } +// get table engine, mysiam or innodb. func getTableEngine(val reflect.Value) string { fun := val.MethodByName("TableEngine") if fun.IsValid() { @@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string { return "" } +// get table index from method. func getTableIndex(val reflect.Value) [][]string { fun := val.MethodByName("TableIndex") if fun.IsValid() { @@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string { return nil } +// get table unique from method func getTableUnique(val reflect.Value) [][]string { fun := val.MethodByName("TableUnique") if fun.IsValid() { @@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { col = strings.ToLower(col) column := col @@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col return column } +// return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) switch elm.Kind() { @@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) { return } +// parse struct tag string func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { attr := make(map[string]bool) tag := make(map[string]string) diff --git a/orm/orm.go b/orm/orm.go index 9e3c3565..71b4daa4 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -40,6 +40,7 @@ type orm struct { var _ Ormer = new(orm) +// get model info and model reflect value func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) @@ -54,6 +55,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) } +// get field info from model info by given field name func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { @@ -62,6 +64,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { return fi } +// read data to model func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) @@ -71,6 +74,7 @@ func (o *orm) Read(md interface{}, cols ...string) error { return nil } +// insert model data to database func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) @@ -83,6 +87,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { return id, nil } +// set auto pk field func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { @@ -93,6 +98,7 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } } +// insert some models to database func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { var cnt int64 @@ -127,6 +133,8 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { return cnt, nil } +// update model to database. +// cols set the columns those want to update. func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) @@ -136,6 +144,7 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { return num, nil } +// delete model in database func (o *orm) Delete(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) @@ -148,6 +157,7 @@ func (o *orm) Delete(md interface{}) (int64, error) { return num, nil } +// create a models to models queryer func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -162,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { return newQueryM2M(md, o, mi, fi, ind) } +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) @@ -223,12 +241,19 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int return nums, err } +// return a QuerySeter for related models to md model. +// it can do all, update, delete in QuerySeter. +// example: +// qs := orm.QueryRelated(post,"Tag") +// qs.All(&[]*Tag{}) +// func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } +// get QuerySeter for related models to md model func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -260,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, return mi, fi, ind, qs } +// get reverse relation QuerySeter func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: @@ -280,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS return q } +// get relation QuerySeter func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: @@ -299,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { return q } +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { name := "" if table, ok := ptrStructOrTableName.(string); ok { @@ -318,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { return } +// switch to another registered database driver by given name. func (o *orm) Using(name string) error { if o.isTx { panic(fmt.Errorf(" transaction has been start, cannot change db")) @@ -335,6 +366,7 @@ func (o *orm) Using(name string) error { return nil } +// begin transaction func (o *orm) Begin() error { if o.isTx { return ErrTxHasBegan @@ -353,6 +385,7 @@ func (o *orm) Begin() error { return nil } +// commit transaction func (o *orm) Commit() error { if o.isTx == false { return ErrTxDone @@ -367,6 +400,7 @@ func (o *orm) Commit() error { return err } +// rollback transaction func (o *orm) Rollback() error { if o.isTx == false { return ErrTxDone @@ -381,14 +415,17 @@ func (o *orm) Rollback() error { return err } +// return a raw query seter for raw sql string. func (o *orm) Raw(query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } +// return current using database Driver func (o *orm) Driver() Driver { return driver(o.alias.Name) } +// create new orm func NewOrm() Ormer { BootStrap() // execute only once diff --git a/orm/orm_conds.go b/orm/orm_conds.go index 91d69986..5b1151e2 100644 --- a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -18,15 +18,19 @@ type condValue struct { isCond bool } +// condition struct. +// work for WHERE conditions. type Condition struct { params []condValue } +// return new condition struct func NewCondition() *Condition { c := &Condition{} return c } +// add expression to condition func (c Condition) And(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition { return &c } +// add NOT expression to condition func (c Condition) AndNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition { return &c } +// combine a condition to current condition func (c *Condition) AndCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition { return c } +// add OR expression to condition func (c Condition) Or(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition { return &c } +// add OR NOT expression to condition func (c Condition) OrNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition { return &c } +// combine a OR condition to current condition func (c *Condition) OrCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } +// check the condition arguments are empty or not. func (c *Condition) IsEmpty() bool { return len(c.params) == 0 } +// clone a condition func (c Condition) clone() *Condition { return &c } diff --git a/orm/orm_log.go b/orm/orm_log.go index 0bb5d6f9..e6df797a 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -13,6 +13,7 @@ type Log struct { *log.Logger } +// set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) d.Logger = log.New(out, "[ORM]", 1e9) @@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error DebugLog.Println(con) } +// statement query logger struct. +// if dev mode, use stmtQueryLog, or use stmtQuerier. type stmtQueryLog struct { alias *alias query string @@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { return d } +// database query logger struct. +// if dev mode, use dbQueryLog, or use dbQuerier. type dbQueryLog struct { alias *alias db dbQuerier diff --git a/orm/orm_object.go b/orm/orm_object.go index 3c6d1f0e..fa644349 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -5,6 +5,7 @@ import ( "reflect" ) +// an insert queryer struct type insertSet struct { mi *modelInfo orm *orm @@ -14,6 +15,7 @@ type insertSet struct { var _ Inserter = new(insertSet) +// insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed @@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { return id, nil } +// close insert queryer statement func (o *insertSet) Close() error { if o.closed { return ErrStmtClosed @@ -52,6 +55,7 @@ func (o *insertSet) Close() error { return o.stmt.Close() } +// create new insert queryer. func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 6f0544d0..f0bc94b7 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -4,6 +4,7 @@ import ( "reflect" ) +// model to model struct type queryM2M struct { md interface{} mi *modelInfo @@ -12,6 +13,13 @@ type queryM2M struct { ind reflect.Value } +// add models to origin models when creating queryM2M. +// example: +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} +// +// make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo @@ -67,6 +75,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { return dbase.InsertValue(orm.db, mi, true, names, values) } +// remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -78,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { return nums, nil } +// check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). Filter(fi.reverseFieldInfoTwo.name, md).Exist() } +// clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() } +// count all related models of origin model func (o *queryM2M) Count() (int64, error) { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() @@ -96,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) +// create new M2M queryer. func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index b25d0542..ad8a9374 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -18,6 +18,10 @@ const ( Col_Except ) +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } func ColValue(opt operator, value interface{}) interface{} { switch opt { case Col_Add, Col_Minus, Col_Multiply, Col_Except: @@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} { return val } +// real query struct type querySet struct { mi *modelInfo cond *Condition @@ -47,6 +52,7 @@ type querySet struct { var _ QuerySeter = new(querySet) +// add condition expression to QuerySeter. func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { if o.cond == nil { o.cond = NewCondition() @@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { return &o } +// add NOT condition to querySeter. func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { if o.cond == nil { o.cond = NewCondition() @@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { return &o } +// set offset number func (o *querySet) setOffset(num interface{}) { o.offset = ToInt64(num) } +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { o.limit = ToInt64(limit) if len(args) > 0 { @@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { return &o } +// add OFFSET value func (o querySet) Offset(offset interface{}) QuerySeter { o.setOffset(offset) return &o } +// add ORDER expression. +// "column" means ASC, "-column" means DESC. func (o querySet) OrderBy(exprs ...string) QuerySeter { o.orders = exprs return &o } +// set relation model to query together. +// it will query relation models and assign to parent model. func (o querySet) RelatedSel(params ...interface{}) QuerySeter { var related []string if len(params) == 0 { @@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter { return &o } +// set condition to QuerySeter. func (o querySet) SetCond(cond *Condition) QuerySeter { o.cond = cond return &o } +// return QuerySeter execution result number func (o *querySet) Count() (int64, error) { return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } +// check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } +// execute update with parameters func (o *querySet) Update(values Params) (int64, error) { return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } +// execute delete func (o *querySet) Delete() (int64, error) { return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } +// return a insert queryer. +// it can be used in times. +// example: +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { return newInsertSet(o.orm, o.mi) } +// query all data and map to containers. +// cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } +// query one row data and map to containers. +// cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { @@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error { return nil } +// query all data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } +// query all data and map to [][]interface +// it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } +// query all data and map to []interface. +// it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } +// create new QuerySeter. func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { o := new(querySet) o.mi = mi diff --git a/orm/orm_raw.go b/orm/orm_raw.go index a713dbac..3f5fb162 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -7,6 +7,7 @@ import ( "time" ) +// raw sql string prepared statement type rawPrepare struct { rs *rawSet stmt stmtQuerier @@ -44,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { return o, nil } +// raw query seter type rawSet struct { query string args []interface{} @@ -52,11 +54,13 @@ type rawSet struct { var _ RawSeter = new(rawSet) +// set args for every query func (o rawSet) SetArgs(args ...interface{}) RawSeter { o.args = args return &o } +// execute raw sql and return sql.Result func (o *rawSet) Exec() (sql.Result, error) { query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) @@ -65,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) { return o.orm.db.Exec(query, args...) } +// set field value to row container func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { switch ind.Kind() { case reflect.Bool: @@ -163,6 +168,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } +// set field value in loop for slice container func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { nInds := *nIndsPtr @@ -233,6 +239,7 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr } } +// query data and map to container func (o *rawSet) QueryRow(containers ...interface{}) error { refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) @@ -362,6 +369,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return nil } +// query data rows and map to container func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) @@ -615,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { return cnt, nil } +// query data to []map[string]interface func (o *rawSet) Values(container *[]Params) (int64, error) { return o.readValues(container) } +// query data to [][]interface func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { return o.readValues(container) } +// query data to []interface func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { return o.readValues(container) } +// return prepared raw statement for used in times. func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } diff --git a/orm/types.go b/orm/types.go index a6487fc0..6f13ed67 100644 --- a/orm/types.go +++ b/orm/types.go @@ -6,11 +6,13 @@ import ( "time" ) +// database driver type Driver interface { Name() string Type() DriverType } +// field info type Fielder interface { String() string FieldType() int @@ -18,6 +20,7 @@ type Fielder interface { RawValue() interface{} } +// orm struct type Ormer interface { Read(interface{}, ...string) error Insert(interface{}) (int64, error) @@ -35,11 +38,13 @@ type Ormer interface { Driver() Driver } +// insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) Close() error } +// query seter type QuerySeter interface { Filter(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter @@ -60,6 +65,7 @@ type QuerySeter interface { ValuesFlat(*ParamsList, string) (int64, error) } +// model to model query struct type QueryM2Mer interface { Add(...interface{}) (int64, error) Remove(...interface{}) (int64, error) @@ -68,11 +74,13 @@ type QueryM2Mer interface { Count() (int64, error) } +// raw query statement type RawPreparer interface { Exec(...interface{}) (sql.Result, error) Close() error } +// raw query seter type RawSeter interface { Exec() (sql.Result, error) QueryRow(...interface{}) error @@ -84,6 +92,7 @@ type RawSeter interface { Prepare() (RawPreparer, error) } +// statement querier type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) @@ -91,6 +100,7 @@ type stmtQuerier interface { QueryRow(args ...interface{}) *sql.Row } +// db querier type dbQuerier interface { Prepare(query string) (*sql.Stmt, error) Exec(query string, args ...interface{}) (sql.Result, error) @@ -98,15 +108,18 @@ type dbQuerier interface { QueryRow(query string, args ...interface{}) *sql.Row } +// transaction beginner type txer interface { Begin() (*sql.Tx, error) } +// transaction ending type txEnder interface { Commit() error Rollback() error } +// base database struct type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) diff --git a/orm/utils.go b/orm/utils.go index 237b3edf..2e347278 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -10,6 +10,7 @@ import ( type StrTo string +// set string func (f *StrTo) Set(v string) { if v != "" { *f = StrTo(v) @@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) { } } +// clean string func (f *StrTo) Clear() { *f = StrTo(0x1E) } +// check string exist func (f StrTo) Exist() bool { return string(f) != string(0x1E) } +// string to bool func (f StrTo) Bool() (bool, error) { return strconv.ParseBool(f.String()) } +// string to float32 func (f StrTo) Float32() (float32, error) { v, err := strconv.ParseFloat(f.String(), 32) return float32(v), err } +// string to float64 func (f StrTo) Float64() (float64, error) { return strconv.ParseFloat(f.String(), 64) } +// string to int func (f StrTo) Int() (int, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int(v), err } +// string to int8 func (f StrTo) Int8() (int8, error) { v, err := strconv.ParseInt(f.String(), 10, 8) return int8(v), err } +// string to int16 func (f StrTo) Int16() (int16, error) { v, err := strconv.ParseInt(f.String(), 10, 16) return int16(v), err } +// string to int32 func (f StrTo) Int32() (int32, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int32(v), err } +// string to int64 func (f StrTo) Int64() (int64, error) { v, err := strconv.ParseInt(f.String(), 10, 64) return int64(v), err } +// string to uint func (f StrTo) Uint() (uint, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint(v), err } +// string to uint8 func (f StrTo) Uint8() (uint8, error) { v, err := strconv.ParseUint(f.String(), 10, 8) return uint8(v), err } +// string to uint16 func (f StrTo) Uint16() (uint16, error) { v, err := strconv.ParseUint(f.String(), 10, 16) return uint16(v), err } +// string to uint31 func (f StrTo) Uint32() (uint32, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint32(v), err } +// string to uint64 func (f StrTo) Uint64() (uint64, error) { v, err := strconv.ParseUint(f.String(), 10, 64) return uint64(v), err } +// string to string func (f StrTo) String() string { if f.Exist() { return string(f) @@ -96,6 +113,7 @@ func (f StrTo) String() string { return "" } +// interface to string func ToStr(value interface{}, args ...int) (s string) { switch v := value.(type) { case bool: @@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) { return s } +// interface to int64 func ToInt64(value interface{}) (d int64) { val := reflect.ValueOf(value) switch value.(type) { @@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) { return } +// snake string, XxYy to xx_yy func snakeString(s string) string { data := make([]byte, 0, len(s)*2) j := false @@ -164,6 +184,7 @@ func snakeString(s string) string { return strings.ToLower(string(data[:len(data)])) } +// camel string, xx_yy to XxYy func camelString(s string) string { data := make([]byte, 0, len(s)) j := false @@ -190,6 +211,7 @@ func camelString(s string) string { type argString []string +// get string by index from string slice func (a argString) Get(i int, args ...string) (r string) { if i >= 0 && i < len(a) { r = a[i] @@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) { type argInt []int +// get int by index from int slice func (a argInt) Get(i int, args ...int) (r int) { if i >= 0 && i < len(a) { r = a[i] @@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) { type argAny []interface{} +// get interface by index from interface slice func (a argAny) Get(i int, args ...interface{}) (r interface{}) { if i >= 0 && i < len(a) { r = a[i] @@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) { return } +// parse time to string with location func timeParse(dateString, format string) (time.Time, error) { tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) return tp, err } +// format time string func timeFormat(t time.Time, format string) string { return t.Format(format) } +// get pointer indirect type func indirectType(v reflect.Type) reflect.Type { switch v.Kind() { case reflect.Ptr: