1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 13:10:54 +00:00

golint orm

This commit is contained in:
astaxie 2015-09-12 21:46:43 +08:00
parent 542e143e55
commit 68ec133aa8
25 changed files with 574 additions and 501 deletions

View File

@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
// listen for orm command and then run it if command arguments passed.
// RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
drops = getDbDropSql(d.al)
drops = getDbDropSQL(d.al)
}
db := d.al.DB
@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
}
}
sqls, indexes := getDbCreateSql(d.al)
sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.Sql
query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
queries = append(queries, idx.SQL)
}
for _, query := range queries {
@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
}
// database creation commander interface implement.
type commandSqlAll struct {
type commandSQLAll struct {
al *alias
}
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) {
func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
}
// run orm line command.
func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al)
func (d *commandSQLAll) Run() error {
sqls, indexes := getDbCreateSQL(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll)
commands["sqlall"] = new(commandSQLAll)
}
// run syncdb command line.
// RunSyncdb run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.

View File

@ -23,11 +23,11 @@ import (
type dbIndex struct {
Table string
Name string
Sql string
SQL string
}
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) {
func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@ -65,7 +65,7 @@ checkColumn:
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
@ -112,7 +112,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
}
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@ -142,7 +142,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto {
switch al.Driver {
case DR_Sqlite, DR_Postgres:
case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
@ -201,7 +201,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n")
sql += "\n)"
if al.Driver == DR_MySQL {
if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
@ -237,7 +237,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{}
index.Table = mi.table
index.Name = name
index.Sql = sql
index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
@ -247,7 +247,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return
}
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var (
@ -264,7 +263,7 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
return v;
return v
case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,

150
orm/db.go
View File

@ -24,12 +24,13 @@ import (
)
const (
format_Date = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05"
formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05"
)
var (
ErrMissPK = errors.New("missed pk value") // missing pk error
// ErrMissPK missing pk error
ErrMissPK = errors.New("missed pk value")
)
var (
@ -216,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
}
if fi.null == false && value == nil {
return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
if fi.autoNow || fi.autoNowAdd && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
@ -282,14 +283,13 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := stmt.Exec(values...); err == nil {
}
res, err := stmt.Exec(values...)
if err == nil {
return res.LastInsertId()
} else {
}
return 0, err
}
}
}
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
@ -339,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows
}
return err
} else {
}
elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind)
}
return nil
}
@ -444,21 +440,20 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
if res, err := q.Exec(query, values...); err == nil {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
} else {
}
return 0, err
}
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
}
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
@ -495,9 +490,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, err
}
// execute delete sql dbQuerier with given struct reflect.Value.
@ -513,14 +507,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, pkValue); err == nil {
res, err := q.Exec(query, pkValue)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@ -529,17 +521,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil {
return num, err
}
}
return num, err
} else {
return 0, err
}
return 0, err
}
// update table-related record by querySet.
@ -565,11 +554,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth)
}
where, args := tables.getCondSql(cond, false, tz)
where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...)
join := tables.getJoinSql()
join := tables.getJoinSQL()
var query, T string
@ -585,13 +574,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok {
switch c.opt {
case Col_Add:
case ColAdd:
cols = append(cols, col+" = "+col+" + ?")
case Col_Minus:
case ColMinus:
cols = append(cols, col+" = "+col+" - ?")
case Col_Multiply:
case ColMultiply:
cols = append(cols, col+" = "+col+" * ?")
case Col_Except:
case ColExcept:
cols = append(cols, col+" = "+col+" / ?")
}
values[i] = c.value
@ -610,12 +599,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, values...); err == nil {
res, err := q.Exec(query, values...)
if err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, err
}
// delete related records.
@ -624,23 +612,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
case od_SET_DEFAULT, od_SET_NULL:
case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT {
if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
case od_DO_NOTHING:
case odDoNothing:
}
}
return nil
@ -661,8 +649,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false, tz)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@ -670,16 +658,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
r, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
rs = r
defer rs.Close()
var ref interface{}
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
@ -702,24 +688,21 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, args...); err == nil {
res, err := q.Exec(query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
}
return num, nil
} else {
return 0, err
}
return 0, err
}
// read related records.
@ -801,11 +784,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz)
groupBy := tables.getGroupSql(qs.groups)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL()
for _, tbl := range tables.tables {
if tbl.sel {
@ -824,11 +807,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
r, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
rs = r
refs := make([]interface{}, colsNum)
for i := range refs {
@ -942,9 +925,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz)
tables.getOrderSql(qs.orders)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL()
Q := d.ins.TableQuote()
@ -959,7 +942,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
}
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args, tz)
@ -984,7 +967,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = d.ins.OperatorSql(operator)
sql = d.ins.OperatorSQL(operator)
switch operator {
case "exact":
if arg == nil {
@ -1112,12 +1095,12 @@ setValue:
)
if len(s) >= 19 {
s = s[:19]
t, err = time.ParseInLocation(format_DateTime, s, tz)
t, err = time.ParseInLocation(formatDateTime, s, tz)
} else {
if len(s) > 10 {
s = s[:10]
}
t, err = time.ParseInLocation(format_Date, s, tz)
t, err = time.ParseInLocation(formatDate, s, tz)
}
t = t.In(DefaultTimeLoc)
@ -1448,11 +1431,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
}
}
where, args := tables.getCondSql(cond, false, tz)
groupBy := tables.getGroupSql(qs.groups)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, qs.offset, qs.limit)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL()
sels := strings.Join(cols, ", ")
@ -1460,13 +1443,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
rs, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, len(cols))
for i := range refs {
var ref interface{}
@ -1481,11 +1461,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
)
for rs.Next() {
if cnt == 0 {
if cols, err := rs.Columns(); err != nil {
cols, err := rs.Columns()
if err != nil {
return 0, err
} else {
columns = cols
}
columns = cols
}
if err := rs.Scan(refs...); err != nil {
@ -1649,7 +1629,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
}
// not implement.
func (d *dbBase) OperatorSql(operator string) string {
func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement)
}

View File

@ -22,15 +22,16 @@ import (
"time"
)
// database driver constant int.
// DriverType database driver constant int.
type DriverType int
// Enum the Database driver
const (
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
DRMySQL // mysql
DRSqlite // sqlite
DROracle // oracle
DRPostgres // pgsql
)
// database driver string.
@ -53,15 +54,15 @@ var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
"mysql": DR_MySQL,
"postgres": DR_Postgres,
"sqlite3": DR_Sqlite,
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
}
dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(),
DR_Postgres: newdbBasePostgres(),
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
DROracle: newdbBaseMysql(),
DRPostgres: newdbBasePostgres(),
}
)
@ -119,7 +120,7 @@ func detectTZ(al *alias) {
}
switch al.Driver {
case DR_MySQL:
case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
@ -147,10 +148,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB"
}
case DR_Sqlite:
case DRSqlite:
al.TZ = time.UTC
case DR_Postgres:
case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
@ -188,12 +189,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil
}
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// Setting the database connect params. Use the database driver self dataSource args.
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
@ -236,7 +238,7 @@ end:
return err
}
// Register a database driver use specify driver name, this can be definition the driver is which database type.
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
@ -248,7 +250,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil
}
// Change the database default used timezone
// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
@ -258,14 +260,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil
}
// Change the max idle conns for *sql.DB, use specify database alias name
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns)
}
// Change the max open conns for *sql.DB, use specify database alias name
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns
@ -275,7 +277,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
}
}
// Get *sql.DB from registered database by db alias name.
// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
@ -284,9 +286,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else {
name = "default"
}
if al, ok := dataBaseCache.get(name); ok {
al, ok := dataBaseCache.get(name)
if ok {
return al.DB, nil
} else {
}
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
}

View File

@ -67,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string {
func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}

View File

@ -66,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string {
func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
@ -101,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0
for _, c := range q {
if c == '?' {
num += 1
num++
}
}
if num == 0 {
@ -114,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num += 1
num++
} else {
data = append(data, c)
}

View File

@ -66,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string {
func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}

View File

@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
// generate join string.
func (t *dbTables) getJoinSql() (join string) {
func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
)
num := len(exprs) - 1
names := make([]string, 0)
var names []string
inner := true
@ -326,7 +326,7 @@ loopFor:
}
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT "
}
if p.isCond {
w, ps := t.getCondSql(p.cond, true, tz)
w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
@ -391,7 +391,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
}
// generate group sql.
func (t *dbTables) getGroupSql(groups []string) (groupSql string) {
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 {
return
}
@ -410,12 +410,12 @@ func (t *dbTables) getGroupSql(groups []string) (groupSql string) {
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
}
groupSql = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return
}
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 {
return
}
@ -439,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}

View File

@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
} else {
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
@ -80,19 +79,19 @@ outFor:
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc)
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(format_Date, s, tz)
t, err = time.ParseInLocation(formatDate, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(format_Date)
v = t.In(tz).Format(formatDate)
} else {
v = t.In(tz).Format(format_DateTime)
v = t.In(tz).Format(formatDateTime)
}
}
}
@ -137,9 +136,9 @@ outFor:
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
arg = v.In(tz).Format(formatDate)
} else {
arg = v.In(tz).Format(format_DateTime)
arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()

View File

@ -19,10 +19,10 @@ import (
)
const (
od_CASCADE = "cascade"
od_SET_NULL = "set_null"
od_SET_DEFAULT = "set_default"
od_DO_NOTHING = "do_nothing"
odCascade = "cascade"
odSetNULL = "set_null"
odSetDefault = "set_default"
odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
)
@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false
}
// Clean model cache. Then you can re-RegisterModel.
// ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()

View File

@ -51,12 +51,10 @@ func registerModel(prefix string, model interface{}) {
}
info := newModelInfo(val)
if info.fields.pk == nil {
outFor:
for _, fi := range info.fields.fieldsDB {
if fi.name == "Id" {
if fi.sf.Tag.Get(defaultStructTagName) == "" {
if strings.ToLower(fi.name) == "id" {
switch fi.addrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true
@ -66,7 +64,6 @@ func registerModel(prefix string, model interface{}) {
}
}
}
}
if info.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name)
@ -298,12 +295,12 @@ end:
}
}
// register models
// RegisterModel register models
func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...)
}
// register models with a prefix
// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -314,7 +311,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
// bootrap models.
// BootStrap bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {

View File

@ -15,49 +15,28 @@
package orm
import (
"errors"
"fmt"
"strconv"
"time"
)
// Define the Type enum
const (
// bool
TypeBooleanField = 1 << iota
// string
TypeCharField
// string
TypeTextField
// time.Time
TypeDateField
// time.Time
TypeDateTimeField
// int8
TypeBitField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint8
TypePositiveBitField
// uint16
TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField
// float64
TypeFloatField
// float64
TypeDecimalField
RelForeignKey
RelOneToOne
RelManyToMany
@ -65,6 +44,7 @@ const (
RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1
)
// A true/false field.
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
}
return err
default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
// A string field
// CharField A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder
var _ Fielder = new(CharField)
// A date, represented in go by a time.Time instance.
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datatime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_Date)
v, err := timeParse(d, formatDate)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
// A date, represented in go by a time.Time instance.
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datatime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_DateTime)
v, err := timeParse(d, formatDateTime)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datatime implement fielder
var _ Fielder = new(DateTimeField)
// A floating-point number represented in go by a float32 value.
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// -32768 to 32767
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// -2147483648 to 2147483647
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// -9223372036854775808 to 9223372036854775807.
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// 0 to 65535
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// 0 to 4294967295
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// 0 to 18446744073709551615
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// A large text field.
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder
var _ Fielder = new(TextField)

View File

@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool
initial StrTo
size int
auto_now bool
auto_now_add bool
autoNow bool
autoNowAdd bool
rel bool
reverse bool
reverseField string
@ -309,20 +309,20 @@ checkType:
if fi.rel && fi.dbcol {
switch onDelete {
case od_CASCADE, od_DO_NOTHING:
case od_SET_DEFAULT:
case odCascade, odDoNothing:
case odSetDefault:
if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case od_SET_NULL:
case odSetNULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = od_CASCADE
onDelete = odCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
@ -350,9 +350,9 @@ checkType:
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.auto_now = true
fi.autoNow = true
} else if attrs["auto_now_add"] {
fi.auto_now_add = true
fi.autoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:

View File

@ -15,7 +15,6 @@
package orm
import (
"errors"
"fmt"
"os"
"reflect"
@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi)
if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only"))
err = fmt.Errorf("one model must have one pk field only")
break
} else {
info.fields.pk = fi

View File

@ -76,21 +76,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField)
// A json field.
type JsonField struct {
type JSONField struct {
Name string
Data string
}
func (e *JsonField) String() string {
func (e *JSONField) String() string {
data, _ := json.Marshal(e)
return string(data)
}
func (e *JsonField) FieldType() int {
func (e *JSONField) FieldType() int {
return TypeTextField
}
func (e *JsonField) SetRaw(value interface{}) error {
func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
@ -99,14 +99,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
}
}
func (e *JsonField) RawValue() interface{} {
func (e *JSONField) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(JsonField)
var _ Fielder = new(JSONField)
type Data struct {
Id int
ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@ -130,7 +130,7 @@ type Data struct {
}
type DataNull struct {
Id int
ID int `orm:"column(id)"`
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
@ -193,7 +193,7 @@ type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
ID int `orm:"column(id)"`
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@ -216,12 +216,12 @@ type DataCustom struct {
// only for mysql
type UserBig struct {
Id uint64
ID uint64 `orm:"column(id)"`
Name string
}
type User struct {
Id int
ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
@ -235,9 +235,9 @@ type User struct {
ShouldSkip string `orm:"-"`
Nums int
Langs SliceStringField `orm:"size(100)"`
Extra JsonField `orm:"type(text)"`
Extra JSONField `orm:"type(text)"`
unexport bool `orm:"-"`
unexport_ bool
unexportBool bool
}
func (u *User) TableIndex() [][]string {
@ -259,7 +259,7 @@ func NewUser() *User {
}
type Profile struct {
Id int
ID int `orm:"column(id)"`
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
@ -276,7 +276,7 @@ func NewProfile() *Profile {
}
type Post struct {
Id int
ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"`
Content string `orm:"type(text)"`
@ -297,7 +297,7 @@ func NewPost() *Post {
}
type Tag struct {
Id int
ID int `orm:"column(id)"`
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
@ -309,7 +309,7 @@ func NewTag() *Tag {
}
type PostTags struct {
Id int
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"`
}
@ -319,7 +319,7 @@ func (m *PostTags) TableName() string {
}
type Comment struct {
Id int
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"`
@ -397,7 +397,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default")
if alias.Driver == DR_MySQL {
if alias.Driver == DRMySQL {
alias.Engine = "INNODB"
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
// package main
@ -59,12 +60,13 @@ import (
"time"
)
// DebugQueries define the debug
const (
Debug_Queries = iota
DebugQueries = iota
)
// Define common vars
var (
// DebugLevel = Debug_Queries
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement")
)
// Params stores the Params
type Params map[string]interface{}
// ParamsList stores paramslist
type ParamsList []interface{}
type orm struct {
@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id)
cnt += 1
cnt++
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// create new orm
// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o
}
// create a new ormer object with specify *sql.DB for query
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias

View File

@ -19,6 +19,7 @@ import (
"strings"
)
// ExprSep define the expression seperation
const (
ExprSep = "__"
)
@ -32,19 +33,19 @@ type condValue struct {
isCond bool
}
// condition struct.
// Condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
// return new condition struct
// NewCondition return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
// add expression to condition
// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
// add NOT expression to condition
// AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a condition to current condition
// AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
// add OR expression to condition
// Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
// add OR NOT expression to condition
// OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a OR condition to current condition
// OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
// check the condition arguments are empty or not.
// IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
// clone a condition
// clone clone a condition
func (c Condition) clone() *Condition {
return &c
}

View File

@ -23,11 +23,12 @@ import (
"time"
)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// set io.Writer to create a Logger.
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
flag = "FAIL"
}
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))

View File

@ -25,11 +25,12 @@ type colValue struct {
type operator int
// define Col operations
const (
Col_Add operator = iota
Col_Minus
Col_Multiply
Col_Except
ColAdd operator = iota
ColMinus
ColMultiply
ColExcept
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@ -38,7 +39,7 @@ const (
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
case ColAdd, ColMinus, ColMultiply, ColExcept:
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}

View File

@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" {
if len(str) >= 19 {
str = str[:19]
t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ)
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc)
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
@ -574,9 +575,10 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
columns, err := rs.Columns()
if err != nil {
return 0, err
} else {
}
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
@ -600,7 +602,6 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
}
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
@ -706,16 +705,16 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
columns, err := rs.Columns()
if err != nil {
return 0, err
} else {
}
cols = columns
refs = make([]interface{}, len(cols))
for i := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
@ -723,17 +722,14 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err

View File

@ -31,13 +31,13 @@ import (
var _ = os.PathSeparator
var (
test_Date = format_Date + " -0700"
test_DateTime = format_DateTime + " -0700"
testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700"
)
func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) {
func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 {
return fmt.Errorf("miss args"), false
return false, fmt.Errorf("miss args")
}
b := args[0]
arg := argAny(args)
@ -71,21 +71,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg:
if err != nil {
return err, false
return false, err
}
return nil, true
return true, nil
}
func AssertIs(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(true, a, args...); ok == false {
if ok, err := ValuesCompare(true, a, args...); ok == false {
return err
}
return nil
}
func AssertNot(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(false, a, args...); ok == false {
if ok, err := ValuesCompare(false, a, args...); ok == false {
return err
}
return nil
@ -208,7 +208,7 @@ func TestModelSyntax(t *testing.T) {
}
}
var Data_Values = map[string]interface{}{
var DataValues = map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
@ -235,7 +235,7 @@ func TestDataTypes(t *testing.T) {
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
@ -244,22 +244,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = Data{Id: 1}
d = Data{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -278,7 +278,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataNull{Id: 1}
d = DataNull{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
@ -309,7 +309,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
d = DataNull{Id: 2}
d = DataNull{ID: 2}
err = dORM.Read(&d)
throwFail(t, err)
@ -362,7 +362,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3}
d = DataNull{ID: 3}
err = dORM.Read(&d)
throwFail(t, err)
@ -402,7 +402,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@ -414,13 +414,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
d = DataCustom{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@ -451,7 +451,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
u := &User{Id: user.Id}
u := &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
@ -461,8 +461,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie"
user.Profile = profile
@ -470,11 +470,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie"))
throwFail(t, AssertIs(u.Profile.Id, profile.Id))
throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName")
@ -487,7 +487,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ"))
@ -497,7 +497,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil))
@ -506,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: 100}
u = &User{ID: 100}
err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows))
@ -516,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
ub = UserBig{Id: 1}
ub = UserBig{ID: 1}
err = dORM.Read(&ub)
throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name"))
@ -586,7 +586,7 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4))
tags := []*Tag{
{Name: "golang", BestPost: &Post{Id: 2}},
{Name: "golang", BestPost: &Post{ID: 2}},
{Name: "example"},
{Name: "format"},
{Name: "c++"},
@ -638,7 +638,7 @@ The program—and web server—godoc processes Go source files to extract docume
}
func TestCustomField(t *testing.T) {
user := User{Id: 2}
user := User{ID: 2}
err := dORM.Read(&user)
throwFailNow(t, err)
@ -648,7 +648,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err)
user = User{Id: 2}
user = User{ID: 2}
err = dORM.Read(&user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2))
@ -889,9 +889,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
throwFailNow(t, AssertIs(users2[0].Id, 0))
throwFailNow(t, AssertIs(users2[1].Id, 0))
throwFailNow(t, AssertIs(users2[2].Id, 0))
throwFailNow(t, AssertIs(users2[0].ID, 0))
throwFailNow(t, AssertIs(users2[1].ID, 0))
throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@ -1112,7 +1112,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) {
// load reverse foreign key
user := User{Id: 3}
user := User{ID: 3}
err := dORM.Read(&user)
throwFailNow(t, err)
@ -1121,7 +1121,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3))
throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err)
@ -1143,8 +1143,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one
profile := Profile{Id: 3}
profile.BestPost = &Post{Id: 2}
profile := Profile{ID: 3}
profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
@ -1183,7 +1183,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
post := Post{Id: 2}
post := Post{ID: 2}
// load rel foreign key
err = dORM.Read(&post)
@ -1204,7 +1204,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m
post = Post{Id: 2}
post = Post{ID: 2}
err = dORM.Read(&post)
throwFailNow(t, err)
@ -1224,7 +1224,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m
tag := Tag{Id: 1}
tag := Tag{ID: 1}
err = dORM.Read(&tag)
throwFailNow(t, err)
@ -1233,19 +1233,19 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
}
func TestQueryM2M(t *testing.T) {
post := Post{Id: 4}
post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags")
tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
@ -1319,7 +1319,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts {
p := post.(*Post)
p.User = &User{Id: 1}
p.User = &User{ID: 1}
_, err := dORM.Insert(post)
throwFailNow(t, err)
}
@ -1459,10 +1459,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64
)
data_values := make(map[string]interface{}, len(Data_Values))
dataValues := make(map[string]interface{}, len(DataValues))
for k, v := range Data_Values {
data_values[strings.ToLower(k)] = v
for k, v := range DataValues {
dataValues[strings.ToLower(k)] = v
}
Q := dDbBaser.TableQuote()
@ -1488,14 +1488,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1))
case "date":
v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testDate))
case "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testDateTime))
default:
throwFail(t, AssertIs(v, data_values[col]))
throwFail(t, AssertIs(v, dataValues[col]))
}
}
@ -1529,16 +1529,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -1553,16 +1553,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -1699,25 +1699,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Add, 100),
"Nums": ColValue(ColAdd, 100),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Minus, 50),
"Nums": ColValue(ColMinus, 50),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Multiply, 3),
"Nums": ColValue(ColMultiply, 3),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Except, 5),
"Nums": ColValue(ColExcept, 5),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
@ -1838,15 +1838,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id))
throwFail(t, AssertIs(pk, u.Id))
throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))

View File

@ -16,6 +16,7 @@ package orm
import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder
@ -43,15 +44,16 @@ type QueryBuilder interface {
String() string
}
// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
} else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet!")
err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
err = errors.New("sqlite query builder is not supported yet!")
err = errors.New("sqlite query builder is not supported yet")
} else {
err = errors.New("unknown driver for query builder!")
err = errors.New("unknown driver for query builder")
}
return
}

View File

@ -20,134 +20,160 @@ import (
"strings"
)
const COMMA_SPACE = ", "
// CommaSpace is the seperation
const CommaSpace = ", "
// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")")
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, COMMA_SPACE)
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, COMMA_SPACE)
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

View File

@ -20,13 +20,13 @@ import (
"time"
)
// database driver
// Driver define database driver
type Driver interface {
Name() string
Type() DriverType
}
// field info
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
@ -34,7 +34,7 @@ type Fielder interface {
RawValue() interface{}
}
// orm struct
// Ormer define the orm interface
type Ormer interface {
Read(interface{}, ...string) error
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
@ -53,13 +53,13 @@ type Ormer interface {
Driver() Driver
}
// insert prepared statement
// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
// query seter
// QuerySeter query seter
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
@ -84,7 +84,7 @@ type QuerySeter interface {
RowsToStruct(interface{}, string, string) (int64, error)
}
// model to model query struct
// QueryM2Mer model to model query struct
type QueryM2Mer interface {
Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error)
@ -93,13 +93,13 @@ type QueryM2Mer interface {
Count() (int64, error)
}
// raw query statement
// RawPreparer raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
// raw query seter
// RawSeter raw query seter
type RawSeter interface {
Exec() (sql.Result, error)
QueryRow(...interface{}) error
@ -113,7 +113,7 @@ type RawSeter interface {
Prepare() (RawPreparer, error)
}
// statement querier
// stmtQuerier statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
@ -162,8 +162,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)

View File

@ -22,9 +22,10 @@ import (
"time"
)
// StrTo is the target string
type StrTo string
// set string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
}
}
// clean string
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
// check string exist
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
// string to bool
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// string to float32
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// string to float64
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// string to int
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// string to int8
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// string to int16
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// string to int32
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// string to int64
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
// string to uint
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// string to uint8
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// string to uint16
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// string to uint31
// Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// string to uint64
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
// string to string
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
@ -127,7 +128,7 @@ func (f StrTo) String() string {
return ""
}
// interface to string
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
// interface to int64
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {