package orm import ( "database/sql" "errors" "fmt" "reflect" "strings" "time" ) const ( format_Date = "2006-01-02" format_DateTime = "2006-01-02 15:04:05" ) var ( ErrMissPK = errors.New("missed pk value") ) var ( operators = map[string]bool{ "exact": true, "iexact": true, "contains": true, "icontains": true, // "regex": true, // "iregex": true, "gt": true, "gte": true, "lt": true, "lte": true, "startswith": true, "endswith": true, "istartswith": true, "iendswith": true, "in": true, // "range": true, // "year": true, // "month": true, // "day": true, // "week_day": true, "isnull": true, // "search": true, } operatorsSQL = map[string]string{ "exact": "= ?", "iexact": "LIKE ?", "contains": "LIKE BINARY ?", "icontains": "LIKE ?", // "regex": "REGEXP BINARY ?", // "iregex": "REGEXP ?", "gt": "> ?", "gte": ">= ?", "lt": "< ?", "lte": "<= ?", "startswith": "LIKE BINARY ?", "endswith": "LIKE BINARY ?", "istartswith": "LIKE ?", "iendswith": "LIKE ?", } ) type dbTable struct { id int index string name string names []string sel bool inner bool mi *modelInfo fi *fieldInfo jtl *dbTable } type dbTables struct { tablesM map[string]*dbTable tables []*dbTable mi *modelInfo base dbBaser } func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { name := strings.Join(names, ExprSep) if j, ok := t.tablesM[name]; ok { j.name = name j.mi = mi j.fi = fi j.inner = inner } else { i := len(t.tables) + 1 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} t.tablesM[name] = jt t.tables = append(t.tables, jt) } return t.tablesM[name] } func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) if _, ok := t.tablesM[name]; ok == false { i := len(t.tables) + 1 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} t.tablesM[name] = jt t.tables = append(t.tables, jt) return jt, true } return t.tablesM[name], false } func (t *dbTables) get(name string) (*dbTable, bool) { j, ok := t.tablesM[name] return j, ok } func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { if depth < 0 || fi.fieldType == RelManyToMany { return related } if prefix == "" { prefix = fi.name } else { prefix = prefix + ExprSep + fi.name } related = append(related, prefix) depth-- for _, fi := range fi.relModelInfo.fields.fieldsRel { related = t.loopDepth(depth, prefix, fi, related) } return related } func (t *dbTables) parseRelated(rels []string, depth int) { relsNum := len(rels) related := make([]string, relsNum) copy(related, rels) relDepth := depth if relsNum != 0 { relDepth = 0 } relDepth-- for _, fi := range t.mi.fields.fieldsRel { related = t.loopDepth(relDepth, "", fi, related) } for i, s := range related { var ( exs = strings.Split(s, ExprSep) names = make([]string, 0, len(exs)) mmi = t.mi cansel = true jtl *dbTable ) for _, ex := range exs { if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { names = append(names, fi.name) mmi = fi.relModelInfo jt := t.set(names, mmi, fi, fi.null == false) jt.jtl = jtl if fi.reverse { cansel = false } if cansel { jt.sel = depth > 0 if i < relsNum { jt.sel = true } } jtl = jt } else { panic(fmt.Sprintf("unknown model/table name `%s`", ex)) } } } } func (t *dbTables) getJoinSql() (join string) { for _, jt := range t.tables { if jt.inner { join += "INNER JOIN " } else { join += "LEFT OUTER JOIN " } var ( table string t1, t2 string c1, c2 string ) t1 = "T0" if jt.jtl != nil { t1 = jt.jtl.index } t2 = jt.index table = jt.mi.table switch { case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: c1 = jt.fi.mi.fields.pk[0].column for _, ffi := range jt.mi.fields.fieldsRel { if jt.fi.mi == ffi.relModelInfo { c2 = ffi.column break } } default: c1 = jt.fi.column c2 = jt.fi.relModelInfo.fields.pk[0].column if jt.fi.reverse { c1 = jt.mi.fields.pk[0].column c2 = jt.fi.reverseFieldInfo.column } } join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2, t2, c2, t1, c1) } return } func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { var ( ffi *fieldInfo jtl *dbTable mmi = mi ) num := len(exprs) - 1 names := make([]string, 0) for i, ex := range exprs { exist := false check: fi, ok := mmi.fields.GetByAny(ex) if ok { if num != i { names = append(names, fi.name) switch { case fi.rel: mmi = fi.relModelInfo if fi.fieldType == RelManyToMany { mmi = fi.relThroughModelInfo } case fi.reverse: mmi = fi.reverseFieldInfo.mi if fi.reverseFieldInfo.fieldType == RelManyToMany { mmi = fi.reverseFieldInfo.relThroughModelInfo } } jt, _ := d.add(names, mmi, fi, fi.null == false) jt.jtl = jtl jtl = jt if fi.rel && fi.fieldType == RelManyToMany { ex = fi.relModelInfo.name goto check } if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany { ex = fi.reverseFieldInfo.mi.name goto check } exist = true } else { if ffi == nil { index = "T0" } else { index = jtl.index } column = fi.column info = fi if jtl != nil { name = jtl.name + ExprSep + fi.name } else { name = fi.name } switch fi.fieldType { case RelManyToMany, RelReverseMany: default: exist = true } } ffi = fi } if exist == false { index = "" column = "" name = "" success = false return } } success = index != "" && column != "" return } func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { if cond == nil || cond.IsEmpty() { return } mi := d.mi // outFor: for i, p := range cond.params { if i > 0 { if p.isOr { where += "OR " } else { where += "AND " } } if p.isNot { where += "NOT " } if p.isCond { w, ps := d.getCondSql(p.cond, true) if w != "" { w = fmt.Sprintf("( %s) ", w) } where += w params = append(params, ps...) } else { exprs := p.exprs num := len(exprs) - 1 operator := "" if operators[exprs[num]] { operator = exprs[num] exprs = exprs[:num] } index, column, _, _, suc := d.parseExprs(mi, exprs) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) } if operator == "" { operator = "exact" } operSql, args := d.base.GetOperatorSql(mi, operator, p.args) where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql) params = append(params, args...) } } if sub == false && where != "" { where = "WHERE " + where } return } func (d *dbTables) getOrderSql(orders []string) (orderSql string) { if len(orders) == 0 { return } orderSqls := make([]string, 0, len(orders)) for _, order := range orders { asc := "ASC" if order[0] == '-' { asc = "DESC" order = order[1:] } exprs := strings.Split(order, ExprSep) index, column, _, _, suc := d.parseExprs(d.mi, exprs) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc)) } orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) return } func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) { if limit == 0 { limit = DefaultRowsLimit } if limit < 0 { // no limit if offset > 0 { limits = fmt.Sprintf("LIMIT 18446744073709551615 OFFSET %d", offset) } } else if offset <= 0 { limits = fmt.Sprintf("LIMIT %d", limit) } else { limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) } return } func newDbTables(mi *modelInfo, base dbBaser) *dbTables { tables := &dbTables{} tables.tablesM = make(map[string]*dbTable) tables.mi = mi tables.base = base return tables } type dbBase struct { ins dbBaser } func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) { exist := true columns := make([]string, 0, len(mi.fields.pk)) values := make([]interface{}, 0, len(mi.fields.pk)) for _, fi := range mi.fields.pk { v := ind.Field(fi.fieldIndex) if fi.fieldType&IsIntegerField > 0 { vu := v.Int() if exist { exist = vu > 0 } values = append(values, vu) } else { vu := v.String() if exist { exist = vu != "" } values = append(values, vu) } columns = append(columns, fi.column) } return columns, values, exist } func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { _, pkValues, _ := d.existPk(mi, ind) for _, column := range mi.fields.orders { fi := mi.fields.columns[column] if fi.dbcol == false || fi.auto && skipAuto { continue } var value interface{} if i, ok := mi.fields.pk.Exist(fi); ok { value = pkValues[i] } else { field := ind.Field(fi.fieldIndex) if fi.isFielder { f := field.Addr().Interface().(Fielder) value = f.RawValue() } else { switch fi.fieldType { case TypeBooleanField: value = field.Bool() case TypeCharField, TypeTextField: value = field.String() case TypeFloatField, TypeDecimalField: value = field.Float() case TypeDateField, TypeDateTimeField: value = field.Interface() default: switch { case fi.fieldType&IsPostiveIntegerField > 0: value = field.Uint() case fi.fieldType&IsIntegerField > 0: value = field.Int() case fi.fieldType&IsRelField > 0: if field.IsNil() { value = nil } else { _, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field)) if fok { value = fvalues[0] } else { value = nil } } if fi.null == false && value == nil { return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) } } } } switch fi.fieldType { case TypeDateField, TypeDateTimeField: if fi.auto_now || fi.auto_now_add && insert { tnow := time.Now() if fi.fieldType == TypeDateField { value = timeFormat(tnow, format_Date) } else { value = timeFormat(tnow, format_DateTime) } if fi.isFielder { f := field.Addr().Interface().(Fielder) f.SetRaw(tnow) } else { field.Set(reflect.ValueOf(tnow)) } } } } columns = append(columns, column) values = append(values, value) } return } func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (*sql.Stmt, error) { dbcols := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.fields.dbcols)) for _, fi := range mi.fields.fieldsDB { if fi.auto == false { dbcols = append(dbcols, fi.column) marks = append(marks, "?") } } qmarks := strings.Join(marks, ", ") columns := strings.Join(dbcols, "`,`") query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) return q.Prepare(query) } func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (int64, error) { _, values, err := d.collectValues(mi, ind, true, true) if err != nil { return 0, err } if res, err := stmt.Exec(values...); err == nil { return res.LastInsertId() } else { return 0, err } } func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { pkNames, pkValues, ok := d.existPk(mi, ind) if ok == false { return ErrMissPK } pkColumns := strings.Join(pkNames, "` = ? AND `") sels := strings.Join(mi.fields.dbcols, "`, `") colsNum := len(mi.fields.dbcols) query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns) refs := make([]interface{}, colsNum) for i, _ := range refs { var ref interface{} refs[i] = &ref } row := q.QueryRow(query, pkValues...) if err := row.Scan(refs...); err != nil { return err } else { elm := reflect.New(mi.addrField.Elem().Type()) md := elm.Interface().(Modeler) md.Init(md) mind := reflect.Indirect(elm) d.setColsValues(mi, &mind, mi.fields.dbcols, refs) ind.Set(mind) } return nil } func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { names, values, err := d.collectValues(mi, ind, true, true) if err != nil { return 0, err } marks := make([]string, len(names)) for i, _ := range marks { marks[i] = "?" } qmarks := strings.Join(marks, ", ") columns := strings.Join(names, "`,`") query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) if res, err := q.Exec(query, values...); err == nil { return res.LastInsertId() } else { return 0, err } } func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { pkNames, pkValues, ok := d.existPk(mi, ind) if ok == false { return 0, ErrMissPK } setNames, setValues, err := d.collectValues(mi, ind, true, false) if err != nil { return 0, err } pkColumns := strings.Join(pkNames, "` = ? AND `") setColumns := strings.Join(setNames, "` = ?, `") query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns) setValues = append(setValues, pkValues...) if res, err := q.Exec(query, setValues...); err == nil { return res.RowsAffected() } else { return 0, err } return 0, nil } func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { names, values, ok := d.existPk(mi, ind) if ok == false { return 0, ErrMissPK } columns := strings.Join(names, "` = ? AND `") query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) if res, err := q.Exec(query, values...); err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { if mi.fields.auto != nil { ind.Field(mi.fields.auto.fieldIndex).SetInt(0) } if len(names) == 1 { err := d.deleteRels(q, mi, values) if err != nil { return num, err } } } return num, err } else { return 0, err } return 0, nil } func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { column := snakeString(col) if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false { panic(fmt.Sprintf("wrong field/column name `%s`", column)) } columns = append(columns, column) values = append(values, val) } if len(columns) == 0 { panic("update params cannot empty") } tables := newDbTables(mi, d.ins) if qs != nil { tables.parseRelated(qs.related, qs.relDepth) } where, args := tables.getCondSql(cond, false) join := tables.getJoinSql() query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where) values = append(values, args...) if res, err := q.Exec(query, values...); err == nil { return res.RowsAffected() } else { return 0, err } return 0, nil } func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case od_CASCADE: cond := NewCondition() cond.And(fmt.Sprintf("%s__in", fi.name), args...) _, err := d.DeleteBatch(q, nil, fi.mi, cond) if err != nil { return err } case od_SET_DEFAULT, od_SET_NULL: cond := NewCondition() cond.And(fmt.Sprintf("%s__in", fi.name), args...) params := Params{fi.column: nil} if fi.onDelete == od_SET_DEFAULT { params[fi.column] = fi.initial.String() } _, err := d.UpdateBatch(q, nil, fi.mi, cond, params) if err != nil { return err } case od_DO_NOTHING: } } return nil } func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) { tables := newDbTables(mi, d.ins) if qs != nil { tables.parseRelated(qs.related, qs.relDepth) } if cond == nil || cond.IsEmpty() { panic("delete operation cannot execute without condition") } where, args := tables.getCondSql(cond, false) join := tables.getJoinSql() colsNum := len(mi.fields.pk) cols := make([]string, colsNum) for i, fi := range mi.fields.pk { cols[i] = fi.column } colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`")) query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { return 0, err } else { rs = r } refs := make([]interface{}, colsNum) for i, _ := range refs { var ref interface{} refs[i] = &ref } args = make([]interface{}, 0) cnt := 0 for rs.Next() { if err := rs.Scan(refs...); err != nil { return 0, err } for _, ref := range refs { args = append(args, reflect.ValueOf(ref).Elem().Interface()) } cnt++ } if cnt == 0 { return 0, nil } if colsNum > 1 { columns := strings.Join(cols, "` = ? AND `") query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) } else { var sql string sql, args = d.ins.GetOperatorSql(mi, "in", args) query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql) } if res, err := q.Exec(query, args...); err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if colsNum == 1 && num > 0 { err := d.deleteRels(q, mi, args) if err != nil { return num, err } } return num, nil } else { return 0, err } return 0, nil } func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) typ := ind.Type() errTyp := true one := true if val.Kind() == reflect.Ptr { tp := typ if ind.Kind() == reflect.Slice { one = false if ind.Type().Elem().Kind() == reflect.Ptr { tp = ind.Type().Elem().Elem() } } errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName } if errTyp { panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName)) } rlimit := qs.limit offset := qs.offset if one { rlimit = 0 offset = 0 } tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) where, args := tables.getCondSql(cond, false) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(offset, rlimit) join := tables.getJoinSql() colsNum := len(mi.fields.dbcols) cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`")) for _, tbl := range tables.tables { if tbl.sel { colsNum += len(tbl.mi.fields.dbcols) cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`")) } } query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { return 0, err } else { rs = r } refs := make([]interface{}, colsNum) for i, _ := range refs { var ref interface{} refs[i] = &ref } slice := ind var cnt int64 for rs.Next() { if one && cnt == 0 || one == false { if err := rs.Scan(refs...); err != nil { return 0, err } elm := reflect.New(mi.addrField.Elem().Type()) md := elm.Interface().(Modeler) md.Init(md) mind := reflect.Indirect(elm) cacheV := make(map[string]*reflect.Value) cacheM := make(map[string]*modelInfo) trefs := refs d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)]) trefs = refs[len(mi.fields.dbcols):] for _, tbl := range tables.tables { if tbl.sel { last := mind names := "" mmi := mi for _, name := range tbl.names { names += name if val, ok := cacheV[names]; ok { last = *val mmi = cacheM[names] } else { fi := mmi.fields.GetByName(name) lastm := mmi mmi := fi.relModelInfo field := reflect.Indirect(last.Field(fi.fieldIndex)) if field.IsValid() { d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)]) for _, fi := range mmi.fields.fieldsReverse { if fi.reverseFieldInfo.mi == lastm { if fi.reverseFieldInfo != nil { field.Field(fi.fieldIndex).Set(last.Addr()) } } } cacheV[names] = &field cacheM[names] = mmi last = field } trefs = trefs[len(mmi.fields.dbcols):] } } } } if one { ind.Set(mind) } else { slice = reflect.Append(slice, mind.Addr()) } } cnt++ } if one == false { ind.Set(slice) } return cnt, nil } func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) where, args := tables.getCondSql(cond, false) tables.getOrderSql(qs.orders) join := tables.getJoinSql() query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where) row := q.QueryRow(query, args...) err = row.Scan(&cnt) return } func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { params := make([]interface{}, len(args)) copy(params, args) sql := "" for i, arg := range args { if len(mi.fields.pk) == 1 { if md, ok := arg.(Modeler); ok { ind := reflect.Indirect(reflect.ValueOf(md)) if _, values, exist := d.existPk(mi, ind); exist { arg = values[0] } else { panic(fmt.Sprintf("`%s` need a valid args value", operator)) } } } params[i] = arg } if operator == "in" { marks := make([]string, len(params)) for i, _ := range marks { marks[i] = "?" } sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) } else { if len(params) > 1 { panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) } sql = operatorsSQL[operator] arg := params[0] switch operator { case "exact": if arg == nil { params[0] = "IS NULL" } case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": param := strings.Replace(ToStr(arg), `%`, `\%`, -1) switch operator { case "iexact": case "contains", "icontains": param = fmt.Sprintf("%%%s%%", param) case "startswith", "istartswith": param = fmt.Sprintf("%s%%", param) case "endswith", "iendswith": param = fmt.Sprintf("%%%s", param) } params[0] = param case "isnull": if b, ok := arg.(bool); ok { if b { sql = "IS NULL" } else { sql = "IS NOT NULL" } params = nil } else { panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg)) } } } return sql, params } func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() fi := mi.fields.GetByColumn(column) field := ind.Field(fi.fieldIndex) value, err := d.getValue(fi, val) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } _, err = d.setValue(fi, value, &field) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } } } func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { if val == nil { return nil, nil } var value interface{} var str *StrTo switch v := val.(type) { case []byte: s := StrTo(string(v)) str = &s case string: s := StrTo(v) str = &s } fieldType := fi.fieldType setValue: switch { case fieldType == TypeBooleanField: if str == nil { switch v := val.(type) { case int64: b := v == 1 value = b default: s := StrTo(ToStr(v)) str = &s } } if str != nil { b, err := str.Bool() if err != nil { return nil, err } value = b } case fieldType == TypeCharField || fieldType == TypeTextField: s := str.String() if str == nil { s = ToStr(val) } value = s case fieldType == TypeDateField || fieldType == TypeDateTimeField: if str == nil { switch v := val.(type) { case time.Time: value = v default: s := StrTo(ToStr(v)) str = &s } } if str != nil { format := format_DateTime if fi.fieldType == TypeDateField { format = format_Date } s := str.String() t, err := timeParse(s, format) if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { return nil, err } value = t } case fieldType&IsIntegerField > 0: if str == nil { s := StrTo(ToStr(val)) str = &s } if str != nil { var err error switch fieldType { case TypeSmallIntegerField: _, err = str.Int16() case TypeIntegerField: _, err = str.Int32() case TypeBigIntegerField: _, err = str.Int64() case TypePositiveSmallIntegerField: _, err = str.Uint16() case TypePositiveIntegerField: _, err = str.Uint32() case TypePositiveBigIntegerField: _, err = str.Uint64() } if err != nil { return nil, err } if fieldType&IsPostiveIntegerField > 0 { v, _ := str.Uint64() value = v } else { v, _ := str.Int64() value = v } } case fieldType == TypeFloatField || fieldType == TypeDecimalField: if str == nil { switch v := val.(type) { case float64: value = v default: s := StrTo(ToStr(v)) str = &s } } if str != nil { v, err := str.Float64() if err != nil { return nil, err } value = v } case fieldType&IsRelField > 0: fieldType = fi.relModelInfo.fields.pk[0].fieldType goto setValue } return value, nil } func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) { fieldType := fi.fieldType isNative := fi.isFielder == false setValue: switch { case fieldType == TypeBooleanField: if isNative { if value == nil { value = false } field.SetBool(value.(bool)) } case fieldType == TypeCharField || fieldType == TypeTextField: if isNative { if value == nil { value = "" } field.SetString(value.(string)) } case fieldType == TypeDateField || fieldType == TypeDateTimeField: if isNative { if value == nil { value = time.Time{} } field.Set(reflect.ValueOf(value)) } case fieldType&IsIntegerField > 0: if fieldType&IsPostiveIntegerField > 0 { if isNative { if value == nil { value = uint64(0) } field.SetUint(value.(uint64)) } } else { if isNative { if value == nil { value = int64(0) } field.SetInt(value.(int64)) } } case fieldType == TypeFloatField || fieldType == TypeDecimalField: if isNative { if value == nil { value = float64(0) } field.SetFloat(value.(float64)) } case fieldType&IsRelField > 0: if value != nil { fieldType = fi.relModelInfo.fields.pk[0].fieldType mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) md := mf.Interface().(Modeler) md.Init(md) field.Set(mf) f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) field = &f goto setValue } } if isNative == false { fd := field.Addr().Interface().(Fielder) err := fd.SetRaw(value) if err != nil { return nil, err } } return value, nil } func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) { var ( maps []Params lists []ParamsList list ParamsList ) typ := 0 switch container.(type) { case *[]Params: typ = 1 case *[]ParamsList: typ = 2 case *ParamsList: typ = 3 default: panic(fmt.Sprintf("unsupport read values type `%T`", container)) } tables := newDbTables(mi, d.ins) var ( cols []string infos []*fieldInfo ) hasExprs := len(exprs) > 0 if hasExprs { cols = make([]string, 0, len(exprs)) infos = make([]*fieldInfo, 0, len(exprs)) for _, ex := range exprs { index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) if suc == false { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } cols = append(cols, fmt.Sprintf("%s.`%s` `%s`", index, col, name)) infos = append(infos, fi) } } else { cols = make([]string, 0, len(mi.fields.dbcols)) infos = make([]*fieldInfo, 0, len(exprs)) for _, fi := range mi.fields.fieldsDB { cols = append(cols, fmt.Sprintf("T0.`%s` `%s`", fi.column, fi.name)) infos = append(infos, fi) } } where, args := tables.getCondSql(cond, false) orderBy := tables.getOrderSql(qs.orders) limit := tables.getLimitSql(qs.offset, qs.limit) join := tables.getJoinSql() sels := strings.Join(cols, ", ") query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit) var rs *sql.Rows if r, err := q.Query(query, args...); err != nil { return 0, err } else { rs = r } refs := make([]interface{}, len(cols)) for i, _ := range refs { var ref interface{} refs[i] = &ref } var ( cnt int64 columns []string ) for rs.Next() { if cnt == 0 { if cols, err := rs.Columns(); err != nil { return 0, err } else { columns = cols } } if err := rs.Scan(refs...); err != nil { return 0, err } switch typ { case 1: params := make(Params, len(cols)) for i, ref := range refs { fi := infos[i] val := reflect.Indirect(reflect.ValueOf(ref)).Interface() value, err := d.getValue(fi, val) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } params[columns[i]] = value } maps = append(maps, params) case 2: params := make(ParamsList, 0, len(cols)) for i, ref := range refs { fi := infos[i] val := reflect.Indirect(reflect.ValueOf(ref)).Interface() value, err := d.getValue(fi, val) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } params = append(params, value) } lists = append(lists, params) case 3: for i, ref := range refs { fi := infos[i] val := reflect.Indirect(reflect.ValueOf(ref)).Interface() value, err := d.getValue(fi, val) if err != nil { panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) } list = append(list, value) } } cnt++ } switch v := container.(type) { case *[]Params: *v = maps case *[]ParamsList: *v = lists case *ParamsList: *v = list } return cnt, nil }