diff --git a/orm/README.md b/orm/README.md
new file mode 100644
index 00000000..a15b8618
--- /dev/null
+++ b/orm/README.md
@@ -0,0 +1,12 @@
+## beego orm
+
+A powerful ORM framework for Go.
+
+Currently in beta and unstable: the API may still change in ways that break your application's build.
+
+## TODO
+- some unimplemented APIs
+- examples
+- docs
+- support postgres
+- support sqlite
\ No newline at end of file
diff --git a/orm/command.go b/orm/command.go
new file mode 100644
index 00000000..78508bdc
--- /dev/null
+++ b/orm/command.go
@@ -0,0 +1,44 @@
+package orm
+
+import (
+	"flag"
+	"fmt"
+	"os"
+)
+
+func printHelp() {
+
+}
+
+func getSqlAll() (sql string) {
+	for _, mi := range modelCache.allOrdered() {
+		_ = mi
+	}
+	return
+}
+
+func runCommand() {
+	if len(os.Args) < 2 || os.Args[1] != "orm" {
+		return
+	}
+
+	_ = flag.NewFlagSet("orm command", flag.ExitOnError)
+
+	args := argString(os.Args[2:])
+	cmd := args.Get(0)
+
+	switch cmd {
+	case "syncdb":
+	case "sqlall":
+		sql := getSqlAll()
+		fmt.Println(sql)
+	default:
+		if cmd != "" {
+			fmt.Printf("unknown command %s", cmd)
+		} else {
+			printHelp()
+		}
+
+		os.Exit(2)
+	}
+}
diff --git a/orm/db.go b/orm/db.go
new file mode 100644
index 00000000..712ce654
--- /dev/null
+++ b/orm/db.go
@@ -0,0 +1,1496 @@
+package orm
+
+import (
+	"database/sql"
+	"errors"
+	"fmt"
+	"reflect"
+	"strings"
+	"time"
+)
+
+const (
+	format_Date     = "2006-01-02"
+	format_DateTime = "2006-01-02 15:04:05"
+)
+
+var (
+	ErrMissPK = errors.New("missed pk value")
+)
+
+var (
+	operators = map[string]bool{
+		"exact":       true,
+		"iexact":      true,
+		"contains":    true,
+		"icontains":   true,
+		// "regex":    true,
+		// "iregex":   true,
+		"gt":          true,
+		"gte":         true,
+		"lt":          true,
+		"lte":         true,
+		"startswith":  true,
+		"endswith":    true,
+		"istartswith": true,
+		"iendswith":   true,
+		"in":          true,
+		// "range":    true,
+		// "year":     true,
+		// "month":    true,
+		// "day":      true,
+		// "week_day": true,
+		"isnull":      true,
+		// "search":   true,
+	}
+	operatorsSQL = map[string]string{
+		"exact":       "= ?",
+		"iexact":      "LIKE ?",
+		"contains":    "LIKE BINARY ?",
+		"icontains":   "LIKE ?",
+		// "regex":     "REGEXP BINARY ?",
+		// "iregex":    "REGEXP ?",
+		"gt":          "> ?",
+		"gte":         ">= ?",
+		"lt":          "< ?",
+		"lte":         "<= ?",
+		"startswith":  "LIKE BINARY ?",
+		"endswith":    "LIKE BINARY ?",
+		"istartswith": "LIKE ?",
+		"iendswith":   "LIKE ?",
+	}
+)
+
+type dbTable struct {
+	id    int
+	index string
+	name  string
+	names []string
+	sel   bool
+	inner bool
+	mi    *modelInfo
+	fi    *fieldInfo
+	jtl   *dbTable
+}
+
+type dbTables struct {
+	tablesM map[string]*dbTable
+	tables  []*dbTable
+	mi      *modelInfo
+	base    dbBaser
+}
+
+func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
+	name := strings.Join(names, ExprSep)
+	if j, ok := t.tablesM[name]; ok {
+		j.name = name
+		j.mi = mi
+		j.fi = fi
+		j.inner = inner
+	} else {
+		i := len(t.tables) + 1
+		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
+		t.tablesM[name] = jt
+		t.tables = append(t.tables, jt)
+	}
+	return t.tablesM[name]
+}
+
+func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
+	name := strings.Join(names, ExprSep)
+	if _, ok := t.tablesM[name]; ok == false {
+		i := len(t.tables) + 1
+		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
+		t.tablesM[name] = jt
+		t.tables = append(t.tables, jt)
+		return jt, true
+	}
+	return t.tablesM[name], false
+}
+
+func (t *dbTables) get(name string) (*dbTable, bool) {
+	j, ok := t.tablesM[name]
+	return j, ok
+}
+
+func (t *dbTables) 
loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { + if depth < 0 || fi.fieldType == RelManyToMany { + return related + } + + if prefix == "" { + prefix = fi.name + } else { + prefix = prefix + ExprSep + fi.name + } + related = append(related, prefix) + + depth-- + for _, fi := range fi.relModelInfo.fields.fieldsRel { + related = t.loopDepth(depth, prefix, fi, related) + } + + return related +} + +func (t *dbTables) parseRelated(rels []string, depth int) { + + relsNum := len(rels) + related := make([]string, relsNum) + copy(related, rels) + + relDepth := depth + + if relsNum != 0 { + relDepth = 0 + } + + relDepth-- + for _, fi := range t.mi.fields.fieldsRel { + related = t.loopDepth(relDepth, "", fi, related) + } + + for i, s := range related { + var ( + exs = strings.Split(s, ExprSep) + names = make([]string, 0, len(exs)) + mmi = t.mi + cansel = true + jtl *dbTable + ) + for _, ex := range exs { + if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { + names = append(names, fi.name) + mmi = fi.relModelInfo + + jt := t.set(names, mmi, fi, fi.null == false) + jt.jtl = jtl + + if fi.reverse { + cansel = false + } + + if cansel { + jt.sel = depth > 0 + + if i < relsNum { + jt.sel = true + } + } + + jtl = jt + + } else { + panic(fmt.Sprintf("unknown model/table name `%s`", ex)) + } + } + } +} + +func (t *dbTables) getJoinSql() (join string) { + for _, jt := range t.tables { + if jt.inner { + join += "INNER JOIN " + } else { + join += "LEFT OUTER JOIN " + } + var ( + table string + t1, t2 string + c1, c2 string + ) + t1 = "T0" + if jt.jtl != nil { + t1 = jt.jtl.index + } + t2 = jt.index + table = jt.mi.table + + switch { + case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + c1 = jt.fi.mi.fields.pk[0].column + for _, ffi := range jt.mi.fields.fieldsRel { + if jt.fi.mi == ffi.relModelInfo { + c2 = ffi.column + break + } + } + default: + c1 = jt.fi.column + c2 = jt.fi.relModelInfo.fields.pk[0].column + + if jt.fi.reverse { + c1 = jt.mi.fields.pk[0].column + c2 = jt.fi.reverseFieldInfo.column + } + } + + join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2, + t2, c2, t1, c1) + } + return +} + +func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column string, info *fieldInfo, success bool) { + var ( + ffi *fieldInfo + jtl *dbTable + mmi = mi + ) + + num := len(exprs) - 1 + names := make([]string, 0) + + for i, ex := range exprs { + exist := false + + check: + fi, ok := mmi.fields.GetByAny(ex) + + if ok { + + if num != i { + names = append(names, fi.name) + + switch { + case fi.rel: + mmi = fi.relModelInfo + if fi.fieldType == RelManyToMany { + mmi = fi.relThroughModelInfo + } + case fi.reverse: + mmi = fi.reverseFieldInfo.mi + if fi.reverseFieldInfo.fieldType == RelManyToMany { + mmi = fi.reverseFieldInfo.relThroughModelInfo + } + } + + jt, _ := d.add(names, mmi, fi, fi.null == false) + jt.jtl = jtl + jtl = jt + + if fi.rel && fi.fieldType == RelManyToMany { + ex = fi.relModelInfo.name + goto check + } + + if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany { + ex = fi.reverseFieldInfo.mi.name + goto check + } + + exist = true + + } else { + + if ffi == nil { + index = "T0" + } else { + index = jtl.index + } + column = fi.column + info = fi + + switch fi.fieldType { + case RelManyToMany, RelReverseMany: + default: + exist = true + } + } + + ffi = fi + } + + if exist == false { + index = "" + column = "" + success = false + return + } + } + + 
success = index != "" && column != "" + return +} + +func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) { + if cond == nil || cond.IsEmpty() { + return + } + + mi := d.mi + + // outFor: + for i, p := range cond.params { + if i > 0 { + if p.isOr { + where += "OR " + } else { + where += "AND " + } + } + if p.isNot { + where += "NOT " + } + if p.isCond { + w, ps := d.getCondSql(p.cond, true) + if w != "" { + w = fmt.Sprintf("( %s) ", w) + } + where += w + params = append(params, ps...) + } else { + exprs := p.exprs + + num := len(exprs) - 1 + operator := "" + if operators[exprs[num]] { + operator = exprs[num] + exprs = exprs[:num] + } + + index, column, _, suc := d.parseExprs(mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) + } + + if operator == "" { + operator = "exact" + } + + operSql, args := d.base.GetOperatorSql(mi, operator, p.args) + + where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql) + params = append(params, args...) + + } + } + + if sub == false && where != "" { + where = "WHERE " + where + } + + return +} + +func (d *dbTables) getOrderSql(orders []string) (orderSql string) { + if len(orders) == 0 { + return + } + + orderSqls := make([]string, 0, len(orders)) + for _, order := range orders { + asc := "ASC" + if order[0] == '-' { + asc = "DESC" + order = order[1:] + } + exprs := strings.Split(order, ExprSep) + + index, column, _, suc := d.parseExprs(d.mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc)) + } + + orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + return +} + +func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) { + if limit == 0 { + limit = DefaultRowsLimit + } + if limit < 0 { + // no limit + if offset > 0 { + limits = fmt.Sprintf("OFFSET %d", offset) + } + } else if offset <= 0 { + limits = fmt.Sprintf("LIMIT %d", limit) + } else { + limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) + } + return +} + +func newDbTables(mi *modelInfo, base dbBaser) *dbTables { + tables := &dbTables{} + tables.tablesM = make(map[string]*dbTable) + tables.mi = mi + tables.base = base + return tables +} + +type dbBase struct { + ins dbBaser +} + +func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) { + exist := true + columns := make([]string, 0, len(mi.fields.pk)) + values := make([]interface{}, 0, len(mi.fields.pk)) + for _, fi := range mi.fields.pk { + v := ind.Field(fi.fieldIndex) + if fi.fieldType&IsIntegerField > 0 { + vu := v.Int() + if exist { + exist = vu > 0 + } + values = append(values, vu) + } else { + vu := v.String() + if exist { + exist = vu != "" + } + values = append(values, vu) + } + columns = append(columns, fi.column) + } + return columns, values, exist +} + +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { + _, pkValues, _ := d.existPk(mi, ind) + for _, column := range mi.fields.orders { + fi := mi.fields.columns[column] + if fi.dbcol == false || fi.auto && skipAuto { + continue + } + var value interface{} + if i, ok := mi.fields.pk.Exist(fi); ok { + value = pkValues[i] + } else { + field := ind.Field(fi.fieldIndex) + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + value = f.RawValue() + } else { + switch 
fi.fieldType { + case TypeBooleanField: + value = field.Bool() + case TypeCharField, TypeTextField: + value = field.String() + case TypeFloatField, TypeDecimalField: + value = field.Float() + case TypeDateField, TypeDateTimeField: + value = field.Interface() + default: + switch { + case fi.fieldType&IsPostiveIntegerField > 0: + value = field.Uint() + case fi.fieldType&IsIntegerField > 0: + value = field.Int() + case fi.fieldType&IsRelField > 0: + if field.IsNil() { + value = nil + } else { + _, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field)) + if fok { + value = fvalues[0] + } else { + value = nil + } + } + if fi.null == false && value == nil { + return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) + } + } + } + } + switch fi.fieldType { + case TypeDateField, TypeDateTimeField: + if fi.auto_now || fi.auto_now_add && insert { + tnow := time.Now() + if fi.fieldType == TypeDateField { + value = timeFormat(tnow, format_Date) + } else { + value = timeFormat(tnow, format_DateTime) + } + if fi.isFielder { + f := field.Addr().Interface().(Fielder) + f.SetRaw(tnow) + } else { + field.Set(reflect.ValueOf(tnow)) + } + } + } + } + columns = append(columns, column) + values = append(values, value) + } + return +} + +func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (*sql.Stmt, error) { + dbcols := make([]string, 0, len(mi.fields.dbcols)) + marks := make([]string, 0, len(mi.fields.dbcols)) + for _, fi := range mi.fields.fieldsDB { + if fi.auto == false { + dbcols = append(dbcols, fi.column) + marks = append(marks, "?") + } + } + qmarks := strings.Join(marks, ", ") + columns := strings.Join(dbcols, "`,`") + + query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) + return q.Prepare(query) +} + +func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (int64, error) { + _, values, err := d.collectValues(mi, ind, true, true) + if err != nil { + return 0, err + } + + if res, err := stmt.Exec(values...); err == nil { + return res.LastInsertId() + } else { + return 0, err + } +} + +func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { + names, values, err := d.collectValues(mi, ind, true, true) + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + for i, _ := range marks { + marks[i] = "?" + } + qmarks := strings.Join(marks, ", ") + columns := strings.Join(names, "`,`") + + query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) + + if res, err := q.Exec(query, values...); err == nil { + return res.LastInsertId() + } else { + return 0, err + } +} + +func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { + pkNames, pkValues, ok := d.existPk(mi, ind) + if ok == false { + return 0, ErrMissPK + } + setNames, setValues, err := d.collectValues(mi, ind, true, false) + if err != nil { + return 0, err + } + + pkColumns := strings.Join(pkNames, "` = ? AND `") + setColumns := strings.Join(setNames, "` = ?, `") + + query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns) + + setValues = append(setValues, pkValues...) 
+ + if res, err := q.Exec(query, setValues...); err == nil { + return res.RowsAffected() + } else { + return 0, err + } + return 0, nil +} + +func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { + names, values, ok := d.existPk(mi, ind) + if ok == false { + return 0, ErrMissPK + } + + columns := strings.Join(names, "` = ? AND `") + + query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) + + if res, err := q.Exec(query, values...); err == nil { + + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + + if num > 0 { + if mi.fields.auto != nil { + ind.Field(mi.fields.auto.fieldIndex).SetInt(0) + } + + if len(names) == 1 { + err := d.deleteRels(q, mi, values) + if err != nil { + return num, err + } + } + } + + return num, err + } else { + return 0, err + } + return 0, nil +} + +func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) { + columns := make([]string, 0, len(params)) + values := make([]interface{}, 0, len(params)) + for col, val := range params { + column := snakeString(col) + if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false { + panic(fmt.Sprintf("wrong field/column name `%s`", column)) + } + columns = append(columns, column) + values = append(values, val) + } + + if len(columns) == 0 { + panic("update params cannot empty") + } + + tables := newDbTables(mi, d.ins) + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + where, args := tables.getCondSql(cond, false) + + join := tables.getJoinSql() + + query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where) + + values = append(values, args...) + + if res, err := q.Exec(query, values...); err == nil { + return res.RowsAffected() + } else { + return 0, err + } + return 0, nil +} + +func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error { + for _, fi := range mi.fields.fieldsReverse { + fi = fi.reverseFieldInfo + switch fi.onDelete { + case od_CASCADE: + cond := NewCondition() + cond.And(fmt.Sprintf("%s__in", fi.name), args...) + _, err := d.DeleteBatch(q, nil, fi.mi, cond) + if err != nil { + return err + } + case od_SET_DEFAULT, od_SET_NULL: + cond := NewCondition() + cond.And(fmt.Sprintf("%s__in", fi.name), args...) 
+ params := Params{fi.column: nil} + if fi.onDelete == od_SET_DEFAULT { + params[fi.column] = fi.initial.String() + } + _, err := d.UpdateBatch(q, nil, fi.mi, cond, params) + if err != nil { + return err + } + case od_DO_NOTHING: + } + } + return nil +} + +func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) { + tables := newDbTables(mi, d.ins) + if qs != nil { + tables.parseRelated(qs.related, qs.relDepth) + } + + if cond == nil || cond.IsEmpty() { + panic("delete operation cannot execute without condition") + } + + where, args := tables.getCondSql(cond, false) + join := tables.getJoinSql() + + colsNum := len(mi.fields.pk) + cols := make([]string, colsNum) + for i, fi := range mi.fields.pk { + cols[i] = fi.column + } + colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`")) + query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where) + + var rs *sql.Rows + if r, err := q.Query(query, args...); err != nil { + return 0, err + } else { + rs = r + } + + refs := make([]interface{}, colsNum) + for i, _ := range refs { + var ref string + refs[i] = &ref + } + + args = make([]interface{}, 0) + cnt := 0 + for rs.Next() { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + for _, ref := range refs { + args = append(args, reflect.ValueOf(ref).Elem().Interface()) + } + cnt++ + } + + if cnt == 0 { + return 0, nil + } + + if colsNum > 1 { + columns := strings.Join(cols, "` = ? AND `") + query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) + } else { + var sql string + sql, args = d.ins.GetOperatorSql(mi, "in", args) + query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql) + } + + if res, err := q.Exec(query, args...); err == nil { + num, err := res.RowsAffected() + if err != nil { + return 0, err + } + + if colsNum == 1 && num > 0 { + err := d.deleteRels(q, mi, args) + if err != nil { + return num, err + } + } + + return num, nil + } else { + return 0, err + } + + return 0, nil +} + +func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) { + + val := reflect.ValueOf(container) + ind := reflect.Indirect(val) + typ := ind.Type() + + errTyp := true + + one := true + + if val.Kind() == reflect.Ptr { + tp := typ + if ind.Kind() == reflect.Slice { + one = false + if ind.Type().Elem().Kind() == reflect.Ptr { + tp = ind.Type().Elem().Elem() + } + } + errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName + } + + if errTyp { + panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName)) + } + + rlimit := qs.limit + offset := qs.offset + if one { + rlimit = 0 + offset = 0 + } + + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSql(cond, false) + orderBy := tables.getOrderSql(qs.orders) + limit := tables.getLimitSql(offset, rlimit) + join := tables.getJoinSql() + + colsNum := len(mi.fields.dbcols) + cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`")) + for _, tbl := range tables.tables { + if tbl.sel { + colsNum += len(tbl.mi.fields.dbcols) + cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`")) + } + } + + query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit) + + var rs *sql.Rows + if r, err := q.Query(query, args...); err != nil { + return 0, err + } else { + rs = r + } + + refs 
:= make([]interface{}, colsNum) + for i, _ := range refs { + var ref string + refs[i] = &ref + } + + slice := ind + + var cnt int64 + for rs.Next() { + if one && cnt == 0 || one == false { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + elm := reflect.New(mi.addrField.Elem().Type()) + md := elm.Interface().(Modeler) + md.Init(md) + mind := reflect.Indirect(elm) + + cacheV := make(map[string]*reflect.Value) + cacheM := make(map[string]*modelInfo) + trefs := refs + + d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)]) + trefs = refs[len(mi.fields.dbcols):] + + for _, tbl := range tables.tables { + if tbl.sel { + last := mind + names := "" + mmi := mi + for _, name := range tbl.names { + names += name + if val, ok := cacheV[names]; ok { + last = *val + mmi = cacheM[names] + } else { + fi := mmi.fields.GetByName(name) + lastm := mmi + mmi := fi.relModelInfo + field := reflect.Indirect(last.Field(fi.fieldIndex)) + d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)]) + for _, fi := range mmi.fields.fieldsReverse { + if fi.reverseFieldInfo.mi == lastm { + if fi.reverseFieldInfo != nil { + field.Field(fi.fieldIndex).Set(last.Addr()) + } + } + } + trefs = trefs[len(mmi.fields.dbcols):] + cacheV[names] = &field + cacheM[names] = mmi + last = field + } + } + } + } + + if one { + ind.Set(mind) + } else { + slice = reflect.Append(slice, mind.Addr()) + } + } + cnt++ + } + + if one == false { + ind.Set(slice) + } + + return cnt, nil +} + +func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + where, args := tables.getCondSql(cond, false) + tables.getOrderSql(qs.orders) + join := tables.getJoinSql() + + query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where) + + row := q.QueryRow(query, args...) + + err = row.Scan(&cnt) + return +} + +func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { + params := make([]interface{}, len(args)) + copy(params, args) + sql := "" + for i, arg := range args { + if len(mi.fields.pk) == 1 { + if md, ok := arg.(Modeler); ok { + ind := reflect.Indirect(reflect.ValueOf(md)) + if _, values, exist := d.existPk(mi, ind); exist { + arg = values[0] + } else { + panic(fmt.Sprintf("`%s` need a valid args value", operator)) + } + } + } + params[i] = arg + } + if operator == "in" { + marks := make([]string, len(params)) + for i, _ := range marks { + marks[i] = "?" 
+ } + sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) + } else { + if len(params) > 1 { + panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) + } + sql = operatorsSQL[operator] + arg := params[0] + switch operator { + case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": + param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + switch operator { + case "iexact", "contains", "icontains": + param = fmt.Sprintf("%%%s%%", param) + case "startswith", "istartswith": + param = fmt.Sprintf("%s%%", param) + case "endswith", "iendswith": + param = fmt.Sprintf("%%%s", param) + } + params[0] = param + case "isnull": + if b, ok := arg.(bool); ok { + if b { + sql = "IS NULL" + } else { + sql = "IS NOT NULL" + } + params = nil + } else { + panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg)) + } + } + } + return sql, params +} + +func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { + for i, column := range cols { + val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() + + fi := mi.fields.GetByColumn(column) + + field := ind.Field(fi.fieldIndex) + + value, err := d.getValue(fi, val) + if err != nil { + panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) + } + + _, err = d.setValue(fi, value, &field) + + if err != nil { + panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) + } + } +} + +func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { + if val == nil { + return nil, nil + } + + var value interface{} + + var str *StrTo + switch v := val.(type) { + case []byte: + s := StrTo(string(v)) + str = &s + case string: + s := StrTo(v) + str = &s + } + + fieldType := fi.fieldType + +setValue: + switch { + case fieldType == TypeBooleanField: + if str == nil { + switch v := val.(type) { + case int64: + b := v == 1 + value = b + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + b, err := str.Bool() + if err != nil { + return nil, err + } + value = b + } + case fieldType == TypeCharField || fieldType == TypeTextField: + s := str.String() + if str == nil { + s = ToStr(val) + } + value = s + case fieldType == TypeDateField || fieldType == TypeDateTimeField: + if str == nil { + switch v := val.(type) { + case time.Time: + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + format := format_DateTime + if fi.fieldType == TypeDateField { + format = format_Date + } + s := str.String() + t, err := timeParse(s, format) + if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { + return nil, err + } + value = t + } + case fieldType&IsIntegerField > 0: + if str == nil { + s := StrTo(ToStr(val)) + str = &s + } + if str != nil { + var err error + switch fieldType { + case TypeSmallIntegerField: + _, err = str.Int16() + case TypeIntegerField: + _, err = str.Int32() + case TypeBigIntegerField: + _, err = str.Int64() + case TypePositiveSmallIntegerField: + _, err = str.Uint16() + case TypePositiveIntegerField: + _, err = str.Uint32() + case TypePositiveBigIntegerField: + _, err = str.Uint64() + } + if err != nil { + return nil, err + } + if fieldType&IsPostiveIntegerField > 0 { + v, _ := str.Uint64() + value = v + } else { + v, _ := str.Int64() + value = v + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if str == nil { + switch v := val.(type) { + case float64: + value = v + default: + s := StrTo(ToStr(v)) + str = &s + 
} + } + if str != nil { + v, err := str.Float64() + if err != nil { + return nil, err + } + value = v + } + case fieldType&IsRelField > 0: + fieldType = fi.relModelInfo.fields.pk[0].fieldType + goto setValue + } + + return value, nil + +} + +func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) { + + fieldType := fi.fieldType + isNative := fi.isFielder == false + +setValue: + switch { + case fieldType == TypeBooleanField: + if isNative { + field.SetBool(value.(bool)) + } + case fieldType == TypeCharField || fieldType == TypeTextField: + if isNative { + field.SetString(value.(string)) + } + case fieldType == TypeDateField || fieldType == TypeDateTimeField: + if isNative { + field.Set(reflect.ValueOf(value)) + } + case fieldType&IsIntegerField > 0: + if fieldType&IsPostiveIntegerField > 0 { + if isNative { + field.SetUint(value.(uint64)) + } + } else { + if isNative { + field.SetInt(value.(int64)) + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if isNative { + field.SetFloat(value.(float64)) + } + case fieldType&IsRelField > 0: + fieldType = fi.relModelInfo.fields.pk[0].fieldType + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + md := mf.Interface().(Modeler) + md.Init(md) + field.Set(mf) + f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) + field = &f + goto setValue + } + + if isNative == false { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + return nil, err + } + } + + return value, nil +} + +func (d *dbBase) xsetValue(fi *fieldInfo, val interface{}, field *reflect.Value) (interface{}, error) { + if val == nil { + return nil, nil + } + + var value interface{} + + var str *StrTo + switch v := val.(type) { + case []byte: + s := StrTo(string(v)) + str = &s + case string: + s := StrTo(v) + str = &s + } + + fieldType := fi.fieldType + isNative := fi.isFielder == false + +setValue: + switch { + case fieldType == TypeBooleanField: + if str == nil { + switch v := val.(type) { + case int64: + b := v == 1 + if isNative { + field.SetBool(b) + } + value = b + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + b, err := str.Bool() + if err != nil { + return nil, err + } + if isNative { + field.SetBool(b) + } + value = b + } + case fieldType == TypeCharField || fieldType == TypeTextField: + s := str.String() + if str == nil { + s = ToStr(val) + } + if isNative { + field.SetString(s) + } + value = s + case fieldType == TypeDateField || fieldType == TypeDateTimeField: + if str == nil { + switch v := val.(type) { + case time.Time: + if isNative { + field.Set(reflect.ValueOf(v)) + } + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + format := format_DateTime + if fi.fieldType == TypeDateField { + format = format_Date + } + + t, err := timeParse(str.String(), format) + if err != nil { + return nil, err + } + if isNative { + field.Set(reflect.ValueOf(t)) + } + value = t + } + case fieldType&IsIntegerField > 0: + if str == nil { + s := StrTo(ToStr(val)) + str = &s + } + if str != nil { + var err error + switch fieldType { + case TypeSmallIntegerField: + value, err = str.Int16() + case TypeIntegerField: + value, err = str.Int32() + case TypeBigIntegerField: + value, err = str.Int64() + case TypePositiveSmallIntegerField: + value, err = str.Uint16() + case TypePositiveIntegerField: + value, err = str.Uint32() + case TypePositiveBigIntegerField: + value, err = str.Uint64() + } + if err != nil { + return nil, err + 
} + if fieldType&IsPostiveIntegerField > 0 { + v, _ := str.Uint64() + if isNative { + field.SetUint(v) + } + } else { + v, _ := str.Int64() + if isNative { + field.SetInt(v) + } + } + } + case fieldType == TypeFloatField || fieldType == TypeDecimalField: + if str == nil { + switch v := val.(type) { + case float64: + if isNative { + field.SetFloat(v) + } + value = v + default: + s := StrTo(ToStr(v)) + str = &s + } + } + if str != nil { + v, err := str.Float64() + if err != nil { + return nil, err + } + if isNative { + field.SetFloat(v) + } + value = v + } + case fieldType&IsRelField > 0: + fieldType = fi.relModelInfo.fields.pk[0].fieldType + mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + md := mf.Interface().(Modeler) + md.Init(md) + field.Set(mf) + f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) + field = &f + goto setValue + } + + if isNative == false { + fd := field.Addr().Interface().(Fielder) + err := fd.SetRaw(value) + if err != nil { + return nil, err + } + } + + return value, nil +} + +func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) { + + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch container.(type) { + case *[]Params: + typ = 1 + case *[]ParamsList: + typ = 2 + case *ParamsList: + typ = 3 + default: + panic(fmt.Sprintf("unsupport read values type `%T`", container)) + } + + tables := newDbTables(mi, d.ins) + + var ( + cols []string + infos []*fieldInfo + ) + + hasExprs := len(exprs) > 0 + + if hasExprs { + cols = make([]string, 0, len(exprs)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, ex := range exprs { + index, col, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", ex)) + } + cols = append(cols, fmt.Sprintf("%s.`%s`", index, col)) + infos = append(infos, fi) + } + } else { + cols = make([]string, 0, len(mi.fields.dbcols)) + infos = make([]*fieldInfo, 0, len(exprs)) + for _, fi := range mi.fields.fieldsDB { + cols = append(cols, fmt.Sprintf("T0.`%s`", fi.column)) + infos = append(infos, fi) + } + } + + where, args := tables.getCondSql(cond, false) + orderBy := tables.getOrderSql(qs.orders) + limit := tables.getLimitSql(qs.offset, qs.limit) + join := tables.getJoinSql() + + sels := strings.Join(cols, ", ") + + query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit) + + var rs *sql.Rows + if r, err := q.Query(query, args...); err != nil { + return 0, err + } else { + rs = r + } + + refs := make([]interface{}, len(cols)) + for i, _ := range refs { + var ref string + refs[i] = &ref + } + + var cnt int64 + for rs.Next() { + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.getValue(fi, val) + if err != nil { + panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) + } + + if hasExprs { + params[exprs[i]] = value + } else { + params[mi.fields.dbcols[i]] = value + } + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for i, ref := range refs { + fi := infos[i] + + val := reflect.Indirect(reflect.ValueOf(ref)).Interface() + + value, err := d.getValue(fi, val) + if err != nil { + panic(fmt.Sprintf("db value convert failed `%v` %s", val, 
err.Error()))
+				}
+
+				params = append(params, value)
+			}
+			lists = append(lists, params)
+		case 3:
+			for i, ref := range refs {
+				fi := infos[i]
+
+				val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
+
+				value, err := d.getValue(fi, val)
+				if err != nil {
+					panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
+				}
+
+				list = append(list, value)
+			}
+		}
+
+		cnt++
+	}
+
+	switch v := container.(type) {
+	case *[]Params:
+		*v = maps
+	case *[]ParamsList:
+		*v = lists
+	case *ParamsList:
+		*v = list
+	}
+
+	return cnt, nil
+}
diff --git a/orm/db_alias.go b/orm/db_alias.go
new file mode 100644
index 00000000..eb3103e8
--- /dev/null
+++ b/orm/db_alias.go
@@ -0,0 +1,127 @@
+package orm
+
+import (
+	"database/sql"
+	"fmt"
+	"os"
+	"sync"
+)
+
+const defaultMaxIdle = 30
+
+type driverType int
+
+const (
+	_ driverType = iota
+	DR_MySQL
+	DR_Sqlite
+	DR_Oracle
+	DR_Postgres
+)
+
+var (
+	dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
+	drivers       = make(map[string]driverType)
+	dbBasers      = map[driverType]dbBaser{
+		DR_MySQL:    newdbBaseMysql(),
+		DR_Sqlite:   newdbBaseSqlite(),
+		DR_Oracle:   newdbBaseMysql(),
+		DR_Postgres: newdbBasePostgres(),
+	}
+)
+
+type _dbCache struct {
+	mux   sync.RWMutex
+	cache map[string]*alias
+}
+
+func (ac *_dbCache) add(name string, al *alias) (added bool) {
+	ac.mux.Lock()
+	defer ac.mux.Unlock()
+	if _, ok := ac.cache[name]; ok == false {
+		ac.cache[name] = al
+		added = true
+	}
+	return
+}
+
+func (ac *_dbCache) get(name string) (al *alias, ok bool) {
+	ac.mux.RLock()
+	defer ac.mux.RUnlock()
+	al, ok = ac.cache[name]
+	return
+}
+
+func (ac *_dbCache) getDefault() (al *alias) {
+	al, _ = ac.get("default")
+	return
+}
+
+type alias struct {
+	Name       string
+	DriverName string
+	DataSource string
+	MaxIdle    int
+	DB         *sql.DB
+	DbBaser    dbBaser
+}
+
+func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
+	if maxIdle <= 0 {
+		maxIdle = defaultMaxIdle
+	}
+
+	al := new(alias)
+	al.Name = name
+	al.DriverName = driverName
+	al.DataSource = dataSource
+	al.MaxIdle = maxIdle
+
+	var (
+		err error
+	)
+
+	if dr, ok := drivers[driverName]; ok {
+		al.DbBaser = dbBasers[dr]
+	} else {
+		err = fmt.Errorf("driver name `%s` has not been registered", driverName)
+		goto end
+	}
+
+	if dataBaseCache.add(name, al) == false {
+		err = fmt.Errorf("db name `%s` already registered, cannot reuse", name)
+		goto end
+	}
+
+	al.DB, err = sql.Open(driverName, dataSource)
+	if err != nil {
+		err = fmt.Errorf("register db `%s`, %s", name, err.Error())
+		goto end
+	}
+
+	err = al.DB.Ping()
+	if err != nil {
+		err = fmt.Errorf("register db `%s`, %s", name, err.Error())
+		goto end
+	}
+
+end:
+	if err != nil {
+		fmt.Println(err.Error())
+		os.Exit(2)
+	}
+}
+
+func RegisterDriver(name string, typ driverType) {
+	if _, ok := drivers[name]; ok == false {
+		drivers[name] = typ
+	} else {
+		fmt.Printf("name `%s` db driver already registered\n", name)
+		os.Exit(2)
+	}
+}
+
+func init() {
+	// RegisterDriver("mysql", DR_MySQL)
+	RegisterDriver("mymysql", DR_MySQL)
+}
diff --git a/orm/db_mysql.go b/orm/db_mysql.go
new file mode 100644
index 00000000..c7cacd90
--- /dev/null
+++ b/orm/db_mysql.go
@@ -0,0 +1,15 @@
+package orm
+
+type dbBaseMysql struct {
+	dbBase
+}
+
+func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) {
+	return d.dbBase.GetOperatorSql(mi, operator, args)
+}
+
+func newdbBaseMysql() dbBaser {
+	b := new(dbBaseMysql)
+	b.ins = b
+	return b
+}
diff --git a/orm/db_oracle.go 
b/orm/db_oracle.go new file mode 100644 index 00000000..b5a27cad --- /dev/null +++ b/orm/db_oracle.go @@ -0,0 +1,11 @@ +package orm + +type dbBaseOracle struct { + dbBase +} + +func newdbBaseOracle() dbBaser { + b := new(dbBaseOracle) + b.ins = b + return b +} diff --git a/orm/db_postgres.go b/orm/db_postgres.go new file mode 100644 index 00000000..1a8a2e3a --- /dev/null +++ b/orm/db_postgres.go @@ -0,0 +1,11 @@ +package orm + +type dbBasePostgres struct { + dbBase +} + +func newdbBasePostgres() dbBaser { + b := new(dbBasePostgres) + b.ins = b + return b +} diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go new file mode 100644 index 00000000..c3c0322e --- /dev/null +++ b/orm/db_sqlite.go @@ -0,0 +1,11 @@ +package orm + +type dbBaseSqlite struct { + dbBase +} + +func newdbBaseSqlite() dbBaser { + b := new(dbBaseSqlite) + b.ins = b + return b +} diff --git a/orm/docs/zh/README.md b/orm/docs/zh/README.md new file mode 100644 index 00000000..e69de29b diff --git a/orm/models.go b/orm/models.go new file mode 100644 index 00000000..53fcb150 --- /dev/null +++ b/orm/models.go @@ -0,0 +1,81 @@ +package orm + +import ( + "log" + "os" + "sync" +) + +const ( + od_CASCADE = "cascade" + od_SET_NULL = "set_null" + od_SET_DEFAULT = "set_default" + od_DO_NOTHING = "do_nothing" + defaultStructTagName = "orm" +) + +var ( + errLog *log.Logger + modelCache = &_modelCache{cache: make(map[string]*modelInfo)} + supportTag = map[string]int{ + "null": 1, + "blank": 1, + "index": 1, + "unique": 1, + "pk": 1, + "auto": 1, + "auto_now": 1, + "auto_now_add": 1, + "max_length": 2, + "choices": 2, + "column": 2, + "default": 2, + "rel": 2, + "reverse": 2, + "rel_table": 2, + "rel_through": 2, + "digits": 2, + "decimals": 2, + "on_delete": 2, + } +) + +func init() { + errLog = log.New(os.Stderr, "[ORM] ", log.Ldate|log.Ltime|log.Lshortfile) +} + +type _modelCache struct { + sync.RWMutex + orders []string + cache map[string]*modelInfo +} + +func (mc *_modelCache) all() map[string]*modelInfo { + m := make(map[string]*modelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +func (mc *_modelCache) allOrdered() []*modelInfo { + m := make([]*modelInfo, 0, len(mc.orders)) + for _, v := range mc.cache { + m = append(m, v) + } + return m +} + +func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} diff --git a/orm/models_boot.go b/orm/models_boot.go new file mode 100644 index 00000000..7440df9e --- /dev/null +++ b/orm/models_boot.go @@ -0,0 +1,212 @@ +package orm + +import ( + "errors" + "fmt" + "os" + "reflect" + "strings" +) + +func RegisterModel(model Modeler) { + info := newModelInfo(model) + model.Init(model) + table := model.GetTableName() + if _, ok := modelCache.get(table); ok { + fmt.Printf("model <%T> redeclared, must be unique\n", model) + os.Exit(2) + } + if info.fields.pk == nil { + fmt.Printf("model <%T> need a primary key field\n", model) + os.Exit(2) + } + info.table = table + info.pkg = getPkgPath(model) + info.model = model + info.manual = true + modelCache.set(table, info) +} + +func BootStrap() { + modelCache.Lock() + defer modelCache.Unlock() + + var ( + err error + models map[string]*modelInfo + ) + + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register alias named `default`") + goto end + } + + models = 
modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + switch fi.fieldType { + case RelReverseMany, RelManyToMany: + elm = elm.Elem() + } + + tn := getTableName(reflect.New(elm).Interface().(Modeler)) + mii, ok := modelCache.get(tn) + if ok == false || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + if fi.rel { + + if mii.fields.pk.IsMulti() { + err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName) + goto end + } + } + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + msg := fmt.Sprintf("filed `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + mn := fi.relThrough[i+1:] + tn := snakeString(mn) + rmi, ok := modelCache.get(tn) + if ok == false || pn != rmi.pkg { + err = errors.New(msg + " cannot find table") + goto end + } + + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + + } else { + err = errors.New(msg) + goto end + } + err = nil + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + + if v := modelCache.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + } + } + } + } + + models = modelCache.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if inModel == false { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if rmi.fields.Add(ffi) == false { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." 
+ ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if added == false { + panic(fmt.Sprintf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + for _, mi := range models { + if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok { + for _, fi := range fields { + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + break mForA + } + } + if found == false { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok { + for _, fi := range fields { + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + break mForB + } + } + if found == false { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + break mForC + } + } + } + if found == false { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + os.Exit(2) + } + + runCommand() +} diff --git a/orm/models_fields.go b/orm/models_fields.go new file mode 100644 index 00000000..432505f2 --- /dev/null +++ b/orm/models_fields.go @@ -0,0 +1,523 @@ +package orm + +import ( + "errors" + "fmt" + "strconv" + "time" +) + +const ( + // bool + TypeBooleanField = 1 << iota + + // string + TypeCharField + + // string + TypeTextField + + // time.Time + TypeDateField + // time.Time + TypeDateTimeField + + // int16 + TypeSmallIntegerField + // int32 + TypeIntegerField + // int64 + TypeBigIntegerField + // uint16 + TypePositiveSmallIntegerField + // uint32 + TypePositiveIntegerField + // uint64 + TypePositiveBigIntegerField + + // float64 + TypeFloatField + // float64 + TypeDecimalField + + RelForeignKey + RelOneToOne + RelManyToMany + RelReverseOne + RelReverseMany +) + +const ( + IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 + IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8 + IsRelField = ^-RelReverseMany >> 12 << 13 + IsFieldType = ^-RelReverseMany<<1 + 1 +) + +// A true/false field. +type BooleanField bool + +func (e BooleanField) Value() bool { + return bool(e) +} + +func (e *BooleanField) Set(d bool) { + *e = BooleanField(d) +} + +func (e *BooleanField) String() string { + return strconv.FormatBool(e.Value()) +} + +func (e *BooleanField) FieldType() int { + return TypeBooleanField +} + +func (e *BooleanField) SetRaw(value interface{}) error { + switch d := value.(type) { + case bool: + e.Set(d) + case string: + v, err := StrTo(d).Bool() + if err != nil { + e.Set(v) + } + return err + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *BooleanField) RawValue() interface{} { + return e.Value() +} + +// A string field +// required values tag: max_length +// The max_length is enforced at the database level and in models’s validation. 
+// eg: `max_length:"120"` +type CharField string + +func (e CharField) Value() string { + return string(e) +} + +func (e *CharField) Set(d string) { + *e = CharField(d) +} + +func (e *CharField) String() string { + return e.Value() +} + +func (e *CharField) FieldType() int { + return TypeCharField +} + +func (e *CharField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *CharField) RawValue() interface{} { + return e.Value() +} + +// A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `attr:"auto_now"` or `attr:"auto_now_add"` +type DateField time.Time + +func (e DateField) Value() time.Time { + return time.Time(e) +} + +func (e *DateField) Set(d time.Time) { + *e = DateField(d) +} + +func (e *DateField) String() string { + return e.Value().String() +} + +func (e *DateField) FieldType() int { + return TypeDateField +} + +func (e *DateField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, format_Date) + if err != nil { + e.Set(v) + } + return err + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *DateField) RawValue() interface{} { + return e.Value() +} + +// A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField time.Time + +func (e DateTimeField) Value() time.Time { + return time.Time(e) +} + +func (e *DateTimeField) Set(d time.Time) { + *e = DateTimeField(d) +} + +func (e *DateTimeField) String() string { + return e.Value().String() +} + +func (e *DateTimeField) FieldType() int { + return TypeDateTimeField +} + +func (e *DateTimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, format_DateTime) + if err != nil { + e.Set(v) + } + return err + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *DateTimeField) RawValue() interface{} { + return e.Value() +} + +// A floating-point number represented in go by a float32 value. 
+type FloatField float64 + +func (e FloatField) Value() float64 { + return float64(e) +} + +func (e *FloatField) Set(d float64) { + *e = FloatField(d) +} + +func (e *FloatField) String() string { + return ToStr(e.Value(), -1, 32) +} + +func (e *FloatField) FieldType() int { + return TypeFloatField +} + +func (e *FloatField) SetRaw(value interface{}) error { + switch d := value.(type) { + case float32: + e.Set(float64(d)) + case float64: + e.Set(d) + case string: + v, err := StrTo(d).Float64() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *FloatField) RawValue() interface{} { + return e.Value() +} + +// -32768 to 32767 +type SmallIntegerField int16 + +func (e SmallIntegerField) Value() int16 { + return int16(e) +} + +func (e *SmallIntegerField) Set(d int16) { + *e = SmallIntegerField(d) +} + +func (e *SmallIntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *SmallIntegerField) FieldType() int { + return TypeSmallIntegerField +} + +func (e *SmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int16: + e.Set(d) + case string: + v, err := StrTo(d).Int16() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *SmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// -2147483648 to 2147483647 +type IntegerField int32 + +func (e IntegerField) Value() int32 { + return int32(e) +} + +func (e *IntegerField) Set(d int32) { + *e = IntegerField(d) +} + +func (e *IntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *IntegerField) FieldType() int { + return TypeIntegerField +} + +func (e *IntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int32: + e.Set(d) + case string: + v, err := StrTo(d).Int32() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *IntegerField) RawValue() interface{} { + return e.Value() +} + +// -9223372036854775808 to 9223372036854775807. 
+type BigIntegerField int64 + +func (e BigIntegerField) Value() int64 { + return int64(e) +} + +func (e *BigIntegerField) Set(d int64) { + *e = BigIntegerField(d) +} + +func (e *BigIntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *BigIntegerField) FieldType() int { + return TypeBigIntegerField +} + +func (e *BigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case int64: + e.Set(d) + case string: + v, err := StrTo(d).Int64() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *BigIntegerField) RawValue() interface{} { + return e.Value() +} + +// 0 to 65535 +type PositiveSmallIntegerField uint16 + +func (e PositiveSmallIntegerField) Value() uint16 { + return uint16(e) +} + +func (e *PositiveSmallIntegerField) Set(d uint16) { + *e = PositiveSmallIntegerField(d) +} + +func (e *PositiveSmallIntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *PositiveSmallIntegerField) FieldType() int { + return TypePositiveSmallIntegerField +} + +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint16: + e.Set(d) + case string: + v, err := StrTo(d).Uint16() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return e.Value() +} + +// 0 to 4294967295 +type PositiveIntegerField uint32 + +func (e PositiveIntegerField) Value() uint32 { + return uint32(e) +} + +func (e *PositiveIntegerField) Set(d uint32) { + *e = PositiveIntegerField(d) +} + +func (e *PositiveIntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *PositiveIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint32: + e.Set(d) + case string: + v, err := StrTo(d).Uint32() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *PositiveIntegerField) RawValue() interface{} { + return e.Value() +} + +// 0 to 18446744073709551615 +type PositiveBigIntegerField uint64 + +func (e PositiveBigIntegerField) Value() uint64 { + return uint64(e) +} + +func (e *PositiveBigIntegerField) Set(d uint64) { + *e = PositiveBigIntegerField(d) +} + +func (e *PositiveBigIntegerField) String() string { + return ToStr(e.Value()) +} + +func (e *PositiveBigIntegerField) FieldType() int { + return TypePositiveIntegerField +} + +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + switch d := value.(type) { + case uint64: + e.Set(d) + case string: + v, err := StrTo(d).Uint64() + if err != nil { + e.Set(v) + } + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *PositiveBigIntegerField) RawValue() interface{} { + return e.Value() +} + +// A large text field. 
+type TextField string + +func (e TextField) Value() string { + return string(e) +} + +func (e *TextField) Set(d string) { + *e = TextField(d) +} + +func (e *TextField) String() string { + return e.Value() +} + +func (e *TextField) FieldType() int { + return TypeTextField +} + +func (e *TextField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + e.Set(d) + default: + return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + } + return nil +} + +func (e *TextField) RawValue() interface{} { + return e.Value() +} diff --git a/orm/models_info_f.go b/orm/models_info_f.go new file mode 100644 index 00000000..429eef2e --- /dev/null +++ b/orm/models_info_f.go @@ -0,0 +1,495 @@ +package orm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +type fieldChoices []StrTo + +func (f *fieldChoices) Add(s StrTo) { + if f.Have(s) == false { + *f = append(*f, s) + } +} + +func (f *fieldChoices) Clear() { + *f = fieldChoices([]StrTo{}) +} + +func (f *fieldChoices) Have(s StrTo) bool { + for _, v := range *f { + if v == s { + return true + } + } + return false +} + +func (f *fieldChoices) Clone() fieldChoices { + return *f +} + +type primaryKeys []*fieldInfo + +func (p *primaryKeys) Add(fi *fieldInfo) { + *p = append(*p, fi) +} + +func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) { + for i, v := range p { + if v == fi { + return i, true + } + } + return -1, false +} + +func (p primaryKeys) IsMulti() bool { + return len(p) > 1 +} + +func (p primaryKeys) IsEmpty() bool { + return len(p) == 0 +} + +type fields struct { + pk primaryKeys + auto *fieldInfo + columns map[string]*fieldInfo + fields map[string]*fieldInfo + fieldsLow map[string]*fieldInfo + fieldsByType map[int][]*fieldInfo + fieldsRel []*fieldInfo + fieldsReverse []*fieldInfo + fieldsDB []*fieldInfo + rels []*fieldInfo + orders []string + dbcols []string +} + +func (f *fields) Add(fi *fieldInfo) (added bool) { + if f.fields[fi.name] == nil && f.columns[fi.column] == nil { + f.columns[fi.column] = fi + f.fields[fi.name] = fi + f.fieldsLow[strings.ToLower(fi.name)] = fi + } else { + return + } + if _, ok := f.fieldsByType[fi.fieldType]; ok == false { + f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + } + f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) + f.orders = append(f.orders, fi.column) + if fi.dbcol { + f.dbcols = append(f.dbcols, fi.column) + f.fieldsDB = append(f.fieldsDB, fi) + } + if fi.rel { + f.fieldsRel = append(f.fieldsRel, fi) + } + if fi.reverse { + f.fieldsReverse = append(f.fieldsReverse, fi) + } + return true +} + +func (f *fields) GetByName(name string) *fieldInfo { + return f.fields[name] +} + +func (f *fields) GetByColumn(column string) *fieldInfo { + return f.columns[column] +} + +func (f *fields) GetByAny(name string) (*fieldInfo, bool) { + if fi, ok := f.fields[name]; ok { + return fi, ok + } + if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + return fi, ok + } + if fi, ok := f.columns[name]; ok { + return fi, ok + } + return nil, false +} + +func newFields() *fields { + f := new(fields) + f.fields = make(map[string]*fieldInfo) + f.fieldsLow = make(map[string]*fieldInfo) + f.columns = make(map[string]*fieldInfo) + f.fieldsByType = make(map[int][]*fieldInfo) + return f +} + +type fieldInfo struct { + mi *modelInfo + fieldIndex int + fieldType int + dbcol bool + inModel bool + name string + fullName string + column string + addrValue *reflect.Value + sf *reflect.StructField + auto bool + pk bool + null bool + blank bool + index bool + unique bool 
+ initial StrTo + choices fieldChoices + maxLength int + auto_now bool + auto_now_add bool + rel bool + reverse bool + reverseField string + reverseFieldInfo *fieldInfo + relTable string + relThrough string + relThroughModelInfo *modelInfo + relModelInfo *modelInfo + digits int + decimals int + isFielder bool + onDelete string +} + +func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { + var ( + tag string + tagValue string + choices fieldChoices + values fieldChoices + initial StrTo + fieldType int + attrs map[string]bool + tags map[string]string + parts []string + addrField reflect.Value + ) + + fi = new(fieldInfo) + + if field.Kind() != reflect.Ptr && field.Kind() != reflect.Slice && field.CanAddr() { + addrField = field.Addr() + } else { + addrField = field + } + + parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags) + + digits := tags["digits"] + decimals := tags["decimals"] + maxLength := tags["max_length"] + onDelete := tags["on_delete"] + +checkType: + switch f := addrField.Interface().(type) { + case Fielder: + fi.isFielder = true + if field.Kind() == reflect.Ptr { + err = fmt.Errorf("the model Fielder can not be use ptr") + goto end + } + fieldType = f.FieldType() + if fieldType&IsRelField > 0 { + err = fmt.Errorf("unsupport rel type custom field") + goto end + } + default: + tag = "rel" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "fk": + fieldType = RelForeignKey + break checkType + case "one": + fieldType = RelOneToOne + break checkType + case "m2m": + fieldType = RelManyToMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } + break checkType + default: + err = fmt.Errorf("error") + goto wrongTag + } + } + tag = "reverse" + tagValue = tags[tag] + if tagValue != "" { + switch tagValue { + case "one": + fieldType = RelReverseOne + break checkType + case "many": + fieldType = RelReverseMany + break checkType + default: + err = fmt.Errorf("error") + goto wrongTag + } + } + + fieldType, err = getFieldType(addrField) + if err != nil { + goto end + } + if fieldType == TypeTextField && maxLength != "" { + fieldType = TypeCharField + } + if fieldType == TypeFloatField && (digits != "" || decimals != "") { + fieldType = TypeDecimalField + } + if fieldType == TypeDateTimeField && attrs["date"] { + fieldType = TypeDateField + } + } + + switch fieldType { + case RelForeignKey, RelOneToOne, RelReverseOne: + if _, ok := addrField.Interface().(Modeler); ok == false { + err = fmt.Errorf("rel/reverse:one field must be implements Modeler") + goto end + } + if field.Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) + goto end + } + case RelManyToMany, RelReverseMany: + if field.Kind() != reflect.Slice { + err = fmt.Errorf("rel/reverse:many field must be slice") + goto end + } else { + if field.Type().Elem().Kind() != reflect.Ptr { + err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) + goto end + } + if _, ok := reflect.New(field.Type().Elem()).Elem().Interface().(Modeler); ok == false { + err = fmt.Errorf("rel/reverse:many slice element must be implements Modeler") + goto end + } + } + } + + if fieldType&IsFieldType == 0 { + err = fmt.Errorf("wrong field type") + goto end + } + + fi.fieldType = fieldType + fi.name = sf.Name + fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.addrValue = &addrField + fi.sf = &sf + 
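+ // what follows fills in the qualified field name, applies the boolean tag attrs (null/blank/index/auto/pk/unique), marks rel/reverse fields as database columns or not, and validates the on_delete choice for FK columns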
fi.fullName = mi.fullName + "." + sf.Name + + fi.null = attrs["null"] + fi.blank = attrs["blank"] + fi.index = attrs["index"] + fi.auto = attrs["auto"] + fi.pk = attrs["pk"] + fi.unique = attrs["unique"] + + switch fieldType { + case RelManyToMany, RelReverseMany, RelReverseOne: + fi.null = false + fi.blank = false + fi.index = false + fi.auto = false + fi.pk = false + fi.unique = false + default: + fi.dbcol = true + } + + switch fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + fi.rel = true + if fieldType == RelOneToOne { + fi.unique = true + } + case RelReverseMany, RelReverseOne: + fi.reverse = true + } + + if fi.rel && fi.dbcol { + switch onDelete { + case od_CASCADE, od_DO_NOTHING: + case od_SET_DEFAULT: + if tags["default"] == "" { + err = errors.New("on_delete: set_default need set field a default value") + goto end + } + case od_SET_NULL: + if fi.null == false { + err = errors.New("on_delete: set_null need set field null") + goto end + } + default: + if onDelete == "" { + onDelete = od_CASCADE + } else { + err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) + goto end + } + } + + fi.onDelete = onDelete + } + + switch fieldType { + case TypeBooleanField: + case TypeCharField: + if maxLength != "" { + v, e := StrTo(maxLength).Int32() + if e != nil { + err = fmt.Errorf("wrong maxLength value `%s`", maxLength) + } else { + fi.maxLength = int(v) + } + } else { + err = fmt.Errorf("maxLength must be specify") + } + case TypeTextField: + fi.index = false + fi.unique = false + case TypeDateField, TypeDateTimeField: + if attrs["auto_now"] { + fi.auto_now = true + } else if attrs["auto_now_add"] { + fi.auto_now_add = true + } + case TypeFloatField: + case TypeDecimalField: + d1 := digits + d2 := decimals + v1, er1 := StrTo(d1).Int16() + v2, er2 := StrTo(d2).Int16() + if er1 != nil || er2 != nil { + err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) + goto end + } + fi.digits = int(v1) + fi.decimals = int(v2) + default: + switch { + case fieldType&IsIntegerField > 0: + case fieldType&IsRelField > 0: + } + } + + if fieldType&IsIntegerField == 0 { + if fi.auto { + err = fmt.Errorf("non-integer type cannot set auto") + goto end + } + + if fi.pk || fi.index || fi.unique { + if fieldType != TypeCharField && fieldType != RelOneToOne { + err = fmt.Errorf("cannot set pk/index/unique") + goto end + } + } + } + + if fi.auto || fi.pk { + if fi.auto { + fi.pk = true + } + fi.null = false + fi.blank = false + fi.index = false + fi.unique = false + } + + if fi.unique { + fi.null = false + fi.blank = false + fi.index = false + } + + parts = strings.Split(tags["choices"], ",") + if len(parts) > 1 { + for _, v := range parts { + choices.Add(StrTo(strings.TrimSpace(v))) + } + } + + initial.Clear() + if v, ok := tags["default"]; ok { + initial.Set(v) + } + + if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { + // can not set default + choices.Clear() + initial.Clear() + } + + values = choices.Clone() + + if initial.Exist() { + values.Add(initial) + } + + for i, v := range values { + switch fieldType { + case TypeBooleanField: + _, err = v.Bool() + case TypeFloatField, TypeDecimalField: + _, err = v.Float64() + case TypeSmallIntegerField: + _, err = v.Int16() + case TypeIntegerField: + _, err = v.Int32() + case TypeBigIntegerField: + _, err = v.Int64() + case TypePositiveSmallIntegerField: + _, err = v.Uint16() + case TypePositiveIntegerField: + _, err = v.Uint32() + case 
TypePositiveBigIntegerField: + _, err = v.Uint64() + } + if err != nil { + if initial.Exist() && len(values) == i { + tag, tagValue = "default", tags["default"] + } else { + tag, tagValue = "choices", tags["choices"] + } + goto wrongTag + } + } + + if len(choices) > 0 && initial.Exist() { + if choices.Have(initial) == false { + err = fmt.Errorf("default value `%s` not in choices `%s`", tags["default"], tags["choices"]) + goto end + } + } + + fi.choices = choices + fi.initial = initial + +end: + if err != nil { + return nil, err + } + return +wrongTag: + return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) +} diff --git a/orm/models_info_m.go b/orm/models_info_m.go new file mode 100644 index 00000000..a6e755b7 --- /dev/null +++ b/orm/models_info_m.go @@ -0,0 +1,130 @@ +package orm + +import ( + "errors" + "fmt" + "os" + "reflect" +) + +type modelInfo struct { + pkg string + name string + fullName string + table string + model Modeler + fields *fields + manual bool + addrField reflect.Value +} + +func newModelInfo(model Modeler) (info *modelInfo) { + var ( + err error + fi *fieldInfo + sf reflect.StructField + ) + + info = &modelInfo{} + info.fields = newFields() + + val := reflect.ValueOf(model) + ind := reflect.Indirect(val) + typ := ind.Type() + + info.addrField = ind.Addr() + + info.name = typ.Name() + info.fullName = typ.PkgPath() + "." + typ.Name() + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + if field.CanAddr() { + addr := field.Addr() + if _, ok := addr.Interface().(*Manager); ok { + continue + } + } + fi, err = newFieldInfo(info, field, sf) + if err != nil { + break + } + added := info.fields.Add(fi) + if added == false { + err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) + break + } + if fi.pk { + if info.fields.pk != nil { + err = errors.New(fmt.Sprintf("one model must have one pk field only")) + break + } else { + info.fields.pk.Add(fi) + } + } + if fi.auto { + info.fields.auto = fi + } + fi.fieldIndex = i + fi.mi = info + } + + if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false { + err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field")) + goto end + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } + +end: + if err != nil { + fmt.Println(err) + os.Exit(2) + } + return +} + +func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { + info = new(modelInfo) + info.fields = newFields() + info.table = m1.table + "_" + m2.table + "_rel" + info.name = camelString(info.table) + info.fullName = m1.pkg + "." + info.name + + fa := new(fieldInfo) + f1 := new(fieldInfo) + f2 := new(fieldInfo) + fa.fieldType = TypeBigIntegerField + fa.auto = true + fa.pk = true + fa.dbcol = true + + f1.dbcol = true + f2.dbcol = true + f1.fieldType = RelForeignKey + f2.fieldType = RelForeignKey + f1.name = camelString(m1.table) + f2.name = camelString(m2.table) + f1.fullName = info.fullName + "." + f1.name + f2.fullName = info.fullName + "." 
+ f2.name + f1.column = m1.table + "_id" + f2.column = m2.table + "_id" + f1.rel = true + f2.rel = true + f1.relTable = m1.table + f2.relTable = m2.table + f1.relModelInfo = m1 + f2.relModelInfo = m2 + f1.mi = info + f2.mi = info + + info.fields.Add(fa) + info.fields.Add(f1) + info.fields.Add(f2) + info.fields.pk.Add(fa) + return +} diff --git a/orm/models_manager.go b/orm/models_manager.go new file mode 100644 index 00000000..46659811 --- /dev/null +++ b/orm/models_manager.go @@ -0,0 +1,56 @@ +package orm + +import () + +// non cleaned field errors +type FieldErrors map[string]error + +func (fe FieldErrors) Get(name string) error { + return fe[name] +} + +func (fe FieldErrors) Set(name string, value error) { + fe[name] = value +} + +type Manager struct { + ins Modeler + inited bool +} + +// func (m *Manager) init(model reflect.Value) { +// elm := model.Elem() +// for i := 0; i < elm.NumField(); i++ { +// field := elm.Field(i) +// if _, ok := field.Interface().(Fielder); ok && field.CanSet() { +// if field.Elem().Kind() != reflect.Struct { +// field.Set(reflect.New(field.Type().Elem())) +// } +// } +// } +// } + +func (m *Manager) Init(model Modeler) Modeler { + if m.inited { + return m.ins + } + m.inited = true + m.ins = model + return model +} + +func (m *Manager) IsInited() bool { + return m.inited +} + +func (m *Manager) Clean() FieldErrors { + return nil +} + +func (m *Manager) CleanFields(name string) FieldErrors { + return nil +} + +func (m *Manager) GetTableName() string { + return getTableName(m.ins) +} diff --git a/orm/models_utils.go b/orm/models_utils.go new file mode 100644 index 00000000..fafc033e --- /dev/null +++ b/orm/models_utils.go @@ -0,0 +1,97 @@ +package orm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +func getTableName(model Modeler) string { + val := reflect.ValueOf(model) + ind := reflect.Indirect(val) + fun := val.MethodByName("TableName") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{}) + if len(vals) > 0 { + val := vals[0] + if val.Kind() == reflect.String { + return val.String() + } + } + } + return snakeString(ind.Type().Name()) +} + +func getPkgPath(model Modeler) string { + val := reflect.ValueOf(model) + return val.Type().Elem().PkgPath() +} + +func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { + column := strings.ToLower(col) + if column == "" { + column = snakeString(sf.Name) + } + switch ft { + case RelForeignKey, RelOneToOne: + column = column + "_id" + case RelManyToMany, RelReverseMany, RelReverseOne: + column = sf.Name + } + return column +} + +func getFieldType(val reflect.Value) (ft int, err error) { + elm := reflect.Indirect(val) + switch elm.Kind() { + case reflect.Int16: + ft = TypeSmallIntegerField + case reflect.Int32, reflect.Int: + ft = TypeIntegerField + case reflect.Int64: + ft = TypeBigIntegerField + case reflect.Uint16: + ft = TypePositiveSmallIntegerField + case reflect.Uint32: + ft = TypePositiveIntegerField + case reflect.Uint64: + ft = TypePositiveBigIntegerField + case reflect.Float32, reflect.Float64: + ft = TypeFloatField + case reflect.Bool: + ft = TypeBooleanField + case reflect.String: + ft = TypeTextField + case reflect.Invalid: + default: + if elm.CanInterface() { + if _, ok := elm.Interface().(time.Time); ok { + ft = TypeDateTimeField + } + } + } + if ft&IsFieldType == 0 { + err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) + } + return +} + +func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { + attr 
:= make(map[string]bool) + tag := make(map[string]string) + for _, v := range strings.Split(data, ";") { + v = strings.TrimSpace(v) + if supportTag[v] == 1 { + attr[v] = true + } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { + name := v[:i] + if supportTag[name] == 2 { + v = v[i+1 : len(v)-1] + tag[name] = v + } + } + } + *attrs = attr + *tags = tag +} diff --git a/orm/orm.go b/orm/orm.go new file mode 100644 index 00000000..769b86c5 --- /dev/null +++ b/orm/orm.go @@ -0,0 +1,111 @@ +package orm + +import ( + "database/sql" + "errors" + "fmt" + "time" +) + +var ( + ErrTXHasBegin = errors.New(" transaction already begin") + ErrTXNotBegin = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrStmtClosed = errors.New(" stmt already closed") + DefaultRowsLimit = 1000 + DefaultRelsDepth = 5 + DefaultTimeLoc = time.Local +) + +type Params map[string]interface{} +type ParamsList []interface{} + +type orm struct { + alias *alias + db dbQuerier + isTx bool +} + +func (o *orm) Object(md Modeler) ObjectSeter { + name := md.GetTableName() + if mi, ok := modelCache.get(name); ok { + return newObject(o, mi, md) + } + panic(fmt.Sprintf(" table name: `%s` not exists", name)) +} + +func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { + name := "" + if table, ok := ptrStructOrTableName.(string); ok { + name = snakeString(table) + } else if m, ok := ptrStructOrTableName.(Modeler); ok { + name = m.GetTableName() + } + if mi, ok := modelCache.get(name); ok { + return newQuerySet(o, mi) + } + panic(fmt.Sprintf(" table name: `%s` not exists", name)) +} + +func (o *orm) Using(name string) error { + if o.isTx { + panic(" transaction has been start, cannot change db") + } + if al, ok := dataBaseCache.get(name); ok { + o.alias = al + o.db = al.DB + } else { + return errors.New(fmt.Sprintf(" unknown db alias name `%s`", name)) + } + return nil +} + +func (o *orm) Begin() error { + if o.isTx { + return ErrTXHasBegin + } + tx, err := o.alias.DB.Begin() + if err != nil { + return err + } + o.isTx = true + o.db = tx + return nil +} + +func (o *orm) Commit() error { + if o.isTx == false { + return ErrTXNotBegin + } + err := o.db.(*sql.Tx).Commit() + if err == nil { + o.isTx = false + o.db = o.alias.DB + } + return err +} + +func (o *orm) Rollback() error { + if o.isTx == false { + return ErrTXNotBegin + } + err := o.db.(*sql.Tx).Rollback() + if err == nil { + o.isTx = false + o.db = o.alias.DB + } + return err +} + +func (o *orm) Raw(query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +func NewOrm() Ormer { + o := new(orm) + err := o.Using("default") + if err != nil { + panic(err) + } + return o +} diff --git a/orm/orm_conds.go b/orm/orm_conds.go new file mode 100644 index 00000000..43ba0f70 --- /dev/null +++ b/orm/orm_conds.go @@ -0,0 +1,94 @@ +package orm + +import ( + "strings" +) + +const ( + ExprSep = "__" +) + +type condValue struct { + exprs []string + args []interface{} + cond *Condition + isOr bool + isNot bool + isCond bool +} + +type Condition struct { + params []condValue +} + +func NewCondition() *Condition { + c := &Condition{} + return c +} + +func (c *Condition) And(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(" args cannot empty") + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) + return c +} + +func (c *Condition) AndNot(expr string, args ...interface{}) *Condition { + if expr == "" || 
len(args) == 0 { + panic(" args cannot empty") + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) + return c +} + +func (c *Condition) AndCond(cond *Condition) *Condition { + if c == cond { + panic("cannot use self as sub cond") + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true}) + } + return c +} + +func (c *Condition) Or(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(" args cannot empty") + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) + return c +} + +func (c *Condition) OrNot(expr string, args ...interface{}) *Condition { + if expr == "" || len(args) == 0 { + panic(" args cannot empty") + } + c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) + return c +} + +func (c *Condition) OrCond(cond *Condition) *Condition { + if c == cond { + panic("cannot use self as sub cond") + } + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) + } + return c +} + +func (c *Condition) IsEmpty() bool { + return len(c.params) == 0 +} + +func (c Condition) Clone() *Condition { + params := c.params + c.params = make([]condValue, len(params)) + copy(c.params, params) + return &c +} + +func (c *Condition) Merge() (expr string, args []interface{}) { + return expr, args +} diff --git a/orm/orm_object.go b/orm/orm_object.go new file mode 100644 index 00000000..1bcb5595 --- /dev/null +++ b/orm/orm_object.go @@ -0,0 +1,92 @@ +package orm + +import ( + "database/sql" + "fmt" + "reflect" +) + +type insertSet struct { + mi *modelInfo + orm *orm + stmt *sql.Stmt + closed bool +} + +func (o *insertSet) Insert(md Modeler) (int64, error) { + if o.closed { + return 0, ErrStmtClosed + } + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + if val.Type() != o.mi.addrField.Type() { + panic(fmt.Sprintf(" need type `%s` but found `%s`", o.mi.addrField.Type(), val.Type())) + } + id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind) + if err != nil { + return id, err + } + if id > 0 { + if o.mi.fields.auto != nil { + ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) + } + } + return id, nil +} + +func (o *insertSet) Close() error { + o.closed = true + return o.stmt.Close() +} + +func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { + bi := new(insertSet) + bi.orm = orm + bi.mi = mi + st, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + if err != nil { + return nil, err + } + bi.stmt = st + return bi, nil +} + +type object struct { + ind reflect.Value + mi *modelInfo + orm *orm +} + +func (o *object) Insert() (int64, error) { + id, err := o.orm.alias.DbBaser.Insert(o.orm.db, o.mi, o.ind) + if err != nil { + return id, err + } + if id > 0 { + if o.mi.fields.auto != nil { + o.ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) + } + } + return id, nil +} + +func (o *object) Update() (int64, error) { + num, err := o.orm.alias.DbBaser.Update(o.orm.db, o.mi, o.ind) + if err != nil { + return num, err + } + return 0, nil +} + +func (o *object) Delete() (int64, error) { + return o.orm.alias.DbBaser.Delete(o.orm.db, o.mi, o.ind) +} + +func newObject(orm *orm, mi *modelInfo, md Modeler) ObjectSeter { + o := new(object) + ind := reflect.Indirect(reflect.ValueOf(md)) + o.ind = ind + o.mi = mi + o.orm = orm + return o +} diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go new file mode 100644 index 00000000..b9fd324a 
--- /dev/null +++ b/orm/orm_queryset.go @@ -0,0 +1,132 @@ +package orm + +import ( + "fmt" +) + +type querySet struct { + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int + offset int64 + orders []string + orm *orm +} + +func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond.And(expr, args...) + return o.Clone() +} + +func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond.AndNot(expr, args...) + return o.Clone() +} + +func (o *querySet) Limit(limit int, args ...int64) QuerySeter { + o.limit = limit + if len(args) > 0 { + o.offset = args[0] + } + return o.Clone() +} + +func (o *querySet) Offset(offset int64) QuerySeter { + o.offset = offset + return o.Clone() +} + +func (o *querySet) OrderBy(orders ...string) QuerySeter { + o.orders = orders + return o.Clone() +} + +func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { + var related []string + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Sprintf(" wrong param kind: %v", val)) + } + } + } + o.related = related + return o.Clone() +} + +func (o querySet) Clone() QuerySeter { + if o.cond != nil { + o.cond = o.cond.Clone() + } + return &o +} + +func (o *querySet) SetCond(cond *Condition) error { + o.cond = cond + return nil +} + +func (o *querySet) Count() (int64, error) { + return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond) +} + +func (o *querySet) Update(values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values) +} + +func (o *querySet) Delete() (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond) +} + +func (o *querySet) PrepareInsert() (Inserter, error) { + return newInsertSet(o.orm, o.mi) +} + +func (o *querySet) All(container interface{}) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) +} + +func (o *querySet) One(container Modeler) error { + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) + if err != nil { + return err + } + if num > 1 { + return ErrMultiRows + } + return nil +} + +func (o *querySet) Values(results *[]Params, args ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results) +} + +func (o *querySet) ValuesList(results *[]ParamsList, args ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results) +} + +func (o *querySet) ValuesFlat(result *ParamsList, arg string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{arg}, result) +} + +func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { + o := new(querySet) + o.mi = mi + o.orm = orm + return o +} diff --git a/orm/orm_raw.go b/orm/orm_raw.go new file mode 100644 index 00000000..3a5a488e --- /dev/null +++ b/orm/orm_raw.go @@ -0,0 +1,192 @@ +package orm + +import ( + "database/sql" + "fmt" + "reflect" +) + +func getResult(res sql.Result) (int64, error) { + if num, err := res.LastInsertId(); err != nil { + return 0, err + } else { + if num > 0 { + return num, nil + } + } + if num, err := res.RowsAffected(); err != nil { + return num, err + } else { + if num > 0 { + return num, nil + } + } + return 0, 
nil +} + +type rawPrepare struct { + rs *rawSet + stmt *sql.Stmt + closed bool +} + +func (o *rawPrepare) Exec(args ...interface{}) (int64, error) { + if o.closed { + return 0, ErrStmtClosed + } + res, err := o.stmt.Exec(args...) + if err != nil { + return 0, err + } + return getResult(res) +} + +func (o *rawPrepare) Close() error { + o.closed = true + return o.stmt.Close() +} + +func newRawPreparer(rs *rawSet) (RawPreparer, error) { + o := new(rawPrepare) + o.rs = rs + st, err := rs.orm.db.Prepare(rs.query) + if err != nil { + return nil, err + } + o.stmt = st + return o, nil +} + +type rawSet struct { + query string + args []interface{} + orm *orm +} + +func (o rawSet) SetArgs(args ...interface{}) RawSeter { + o.args = args + return &o +} + +func (o *rawSet) Exec() (int64, error) { + res, err := o.orm.db.Exec(o.query, o.args...) + if err != nil { + return 0, err + } + return getResult(res) +} + +func (o *rawSet) Mapper(...interface{}) (int64, error) { + //TODO + return 0, nil +} + +func (o *rawSet) readValues(container interface{}) (int64, error) { + var ( + maps []Params + lists []ParamsList + list ParamsList + ) + + typ := 0 + switch container.(type) { + case *[]Params: + typ = 1 + case *[]ParamsList: + typ = 2 + case *ParamsList: + typ = 3 + default: + panic(fmt.Sprintf("unsupport read values type `%T`", container)) + } + + var rs *sql.Rows + if r, err := o.orm.db.Query(o.query, o.args...); err != nil { + return 0, err + } else { + rs = r + } + + var ( + refs []interface{} + cnt int64 + cols []string + ) + for rs.Next() { + if cnt == 0 { + if columns, err := rs.Columns(); err != nil { + return 0, err + } else { + cols = columns + refs = make([]interface{}, len(cols)) + for i, _ := range refs { + var ref string + refs[i] = &ref + } + } + } + + if err := rs.Scan(refs...); err != nil { + return 0, err + } + + switch typ { + case 1: + params := make(Params, len(cols)) + for i, ref := range refs { + value := reflect.Indirect(reflect.ValueOf(ref)).Interface() + params[cols[i]] = value + } + maps = append(maps, params) + case 2: + params := make(ParamsList, 0, len(cols)) + for _, ref := range refs { + value := reflect.Indirect(reflect.ValueOf(ref)).Interface() + params = append(params, value) + } + lists = append(lists, params) + case 3: + for _, ref := range refs { + value := reflect.Indirect(reflect.ValueOf(ref)).Interface() + list = append(list, value) + } + } + + cnt++ + } + + switch v := container.(type) { + case *[]Params: + *v = maps + case *[]ParamsList: + *v = lists + case *ParamsList: + *v = list + } + + return cnt, nil +} + +func (o *rawSet) Values(container *[]Params) (int64, error) { + return o.readValues(container) +} + +func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { + return o.readValues(container) +} + +func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { + return o.readValues(container) +} + +func (o *rawSet) Prepare() (RawPreparer, error) { + return newRawPreparer(o) +} + +func newRawSet(orm *orm, query string, args []interface{}) RawSeter { + o := new(rawSet) + o.query = query + o.args = args + o.orm = orm + return o +} diff --git a/orm/types.go b/orm/types.go new file mode 100644 index 00000000..50217246 --- /dev/null +++ b/orm/types.go @@ -0,0 +1,97 @@ +package orm + +import ( + "database/sql" + "reflect" +) + +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} +} + +type Modeler interface { + Init(Modeler) Modeler + IsInited() bool + Clean() FieldErrors + 
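+ // CleanFields is meant to validate a single named field, Clean the whole model; both are still stubs on Manager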
CleanFields(string) FieldErrors + GetTableName() string +} + +type Ormer interface { + Object(Modeler) ObjectSeter + QueryTable(interface{}) QuerySeter + Using(string) error + Begin() error + Commit() error + Rollback() error + Raw(string, ...interface{}) RawSeter +} + +type ObjectSeter interface { + Insert() (int64, error) + Update() (int64, error) + Delete() (int64, error) +} + +type Inserter interface { + Insert(Modeler) (int64, error) + Close() error +} + +type QuerySeter interface { + Filter(string, ...interface{}) QuerySeter + Exclude(string, ...interface{}) QuerySeter + Limit(int, ...int64) QuerySeter + Offset(int64) QuerySeter + OrderBy(...string) QuerySeter + RelatedSel(...interface{}) QuerySeter + Clone() QuerySeter + SetCond(*Condition) error + Count() (int64, error) + Update(Params) (int64, error) + Delete() (int64, error) + PrepareInsert() (Inserter, error) + + All(interface{}) (int64, error) + One(Modeler) error + Values(*[]Params, ...string) (int64, error) + ValuesList(*[]ParamsList, ...string) (int64, error) + ValuesFlat(*ParamsList, string) (int64, error) +} + +type RawPreparer interface { + Close() error +} + +type RawSeter interface { + Exec() (int64, error) + Mapper(...interface{}) (int64, error) + Values(*[]Params) (int64, error) + ValuesList(*[]ParamsList) (int64, error) + ValuesFlat(*ParamsList) (int64, error) + Prepare() (RawPreparer, error) +} + +type dbQuerier interface { + Prepare(query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +type dbBaser interface { + Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error) + InsertStmt(*sql.Stmt, *modelInfo, reflect.Value) (int64, error) + Update(dbQuerier, *modelInfo, reflect.Value) (int64, error) + Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error) + ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error) + UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error) + DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) + Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) + GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) + PrepareInsert(dbQuerier, *modelInfo) (*sql.Stmt, error) + ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) +} diff --git a/orm/utils.go b/orm/utils.go new file mode 100644 index 00000000..7dd54896 --- /dev/null +++ b/orm/utils.go @@ -0,0 +1,181 @@ +package orm + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +type StrTo string + +func (f *StrTo) Set(v string) { + if v != "" { + *f = StrTo(v) + } else { + f.Clear() + } +} + +func (f *StrTo) Clear() { + *f = StrTo(0x1E) +} + +func (f StrTo) Exist() bool { + return string(f) != string(0x1E) +} + +func (f StrTo) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +func (f StrTo) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +func (f StrTo) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +func (f StrTo) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +func (f StrTo) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + return 
int64(v), err +} + +func (f StrTo) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +func (f StrTo) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +func (f StrTo) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + return uint64(v), err +} + +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:len(data)])) +} + +func camelString(s string) string { + data := make([]byte, 0, len(s)) + j := false + k := false + num := len(s) - 1 + for i := 0; i <= num; i++ { + d := s[i] + if k == false && d >= 'A' && d <= 'Z' { + k = true + } + if d >= 'a' && d <= 'z' && (j || k == false) { + d = d - 32 + j = false + k = true + } + if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { + j = true + continue + } + data = append(data, d) + } + return string(data[:len(data)]) +} + +type argString []string + +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +func timeFormat(t time.Time, format string) string { + return t.Format(format) +}
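
A minimal usage sketch of the API declared in this patch (not part of the diff itself, and untested). The import path, the `orm` struct-tag key and the RegisterDataBase/RegisterModel helpers are assumptions; they do not appear in this diff and may differ in the actual package.

package main

import (
	"fmt"

	"github.com/astaxie/beego/orm" // assumed import path for this package
	_ "github.com/go-sql-driver/mysql" // assumed MySQL driver
)

// User embeds orm.Manager, which supplies the Modeler methods defined in types.go.
// The "orm" tag key is an assumption: the tag parser reads whatever
// defaultStructTagName is set to elsewhere in the package.
type User struct {
	orm.Manager
	Id   int    `orm:"auto"`
	Name string `orm:"max_length(100)"`
}

func main() {
	// Registration helpers are assumed to exist outside this patch.
	orm.RegisterDataBase("default", "mysql", "user:pass@/dbname?charset=utf8")
	orm.RegisterModel(new(User))

	o := orm.NewOrm()

	// Insert through ObjectSeter. Manager.Init must see the concrete model first,
	// because GetTableName reflects on the instance passed to Init.
	u := new(User)
	u.Init(u)
	u.Name = "gopher"
	id, err := o.Object(u).Insert()
	fmt.Println(id, err)

	// Query through QuerySeter using the field__operator expression syntax.
	var users []*User
	num, err := o.QueryTable("user").Filter("name__icontains", "go").Limit(10).All(&users)
	fmt.Println(num, err)

	// Raw SQL read into []orm.Params.
	var rows []orm.Params
	o.Raw("SELECT id, name FROM user WHERE id = ?", id).Values(&rows)
}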