1
0
mirror of https://github.com/astaxie/beego.git synced 2024-12-22 21:50:50 +00:00

Move orm to pkg/orm

This commit is contained in:
Ming Deng 2020-07-15 10:04:22 +08:00
parent 3db31385cf
commit ffe1d52120
33 changed files with 13229 additions and 0 deletions

159
pkg/orm/README.md Normal file
View File

@ -0,0 +1,159 @@
# beego orm
[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
A powerful orm framework for go.
It is heavily influenced by Django ORM, SQLAlchemy.
**Support Database:**
* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq)
* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
Passed all test, but need more feedback.
**Features:**
* full go type support
* easy for usage, simple CRUD operation
* auto join with relation table
* cross DataBase compatible query
* Raw SQL query / mapper without orm model
* full test keep stable and strong
more features please read the docs
**Install:**
go get github.com/astaxie/beego/orm
## Changelog
* 2013-08-19: support table auto create
* 2013-08-13: update test for database types
* 2013-08-13: go type support, such as int8, uint8, byte, rune
* 2013-08-13: date / datetime timezone support very well
## Quick Start
#### Simple Usage
```go
package main
import (
"fmt"
"github.com/astaxie/beego/orm"
_ "github.com/go-sql-driver/mysql" // import your used driver
)
// Model Struct
type User struct {
Id int `orm:"auto"`
Name string `orm:"size(100)"`
}
func init() {
// register model
orm.RegisterModel(new(User))
// set default database
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
// create table
orm.RunSyncdb("default", false, true)
}
func main() {
o := orm.NewOrm()
user := User{Name: "slene"}
// insert
id, err := o.Insert(&user)
// update
user.Name = "astaxie"
num, err := o.Update(&user)
// read one
u := User{Id: user.Id}
err = o.Read(&u)
// delete
num, err = o.Delete(&u)
}
```
#### Next with relation
```go
type Post struct {
Id int `orm:"auto"`
Title string `orm:"size(100)"`
User *User `orm:"rel(fk)"`
}
var posts []*Post
qs := o.QueryTable("post")
num, err := qs.Filter("User__Name", "slene").All(&posts)
```
#### Use Raw sql
If you don't like ORMuse Raw SQL to query / mapping without ORM setting
```go
var maps []Params
num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps)
if num > 0 {
fmt.Println(maps[0]["id"])
}
```
#### Transaction
```go
o.Begin()
...
user := User{Name: "slene"}
id, err := o.Insert(&user)
if err == nil {
o.Commit()
} else {
o.Rollback()
}
```
#### Debug Log Queries
In development env, you can simple use
```go
func main() {
orm.Debug = true
...
```
enable log queries.
output include all queries, such as exec / prepare / transaction.
like this:
```go
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene`
...
```
note: not recommend use this in product env.
## Docs
more details and examples in docs and test
[documents](http://beego.me/docs/mvc/model/overview.md)

283
pkg/orm/cmd.go Normal file
View File

@ -0,0 +1,283 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"flag"
"fmt"
"os"
"strings"
)
type commander interface {
Parse([]string)
Run() error
}
var (
commands = make(map[string]commander)
)
// print help.
func printHelp(errs ...string) {
content := `orm command usage:
syncdb - auto create tables
sqlall - print sql of create tables
help - print this help
`
if len(errs) > 0 {
fmt.Println(errs[0])
}
fmt.Println(content)
os.Exit(2)
}
// RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
BootStrap()
args := argString(os.Args[2:])
name := args.Get(0)
if name == "help" {
printHelp()
}
if cmd, ok := commands[name]; ok {
cmd.Parse(os.Args[3:])
cmd.Run()
os.Exit(0)
} else {
if name == "" {
printHelp()
} else {
printHelp(fmt.Sprintf("unknown command %s", name))
}
}
}
// sync database struct command interface.
type commandSyncDb struct {
al *alias
force bool
verbose bool
noInfo bool
rtOnError bool
}
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.BoolVar(&d.force, "force", false, "drop tables before create")
flagSet.BoolVar(&d.verbose, "v", false, "verbose info")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
drops = getDbDropSQL(d.al)
}
db := d.al.DB
if d.force {
for i, mi := range modelCache.allOrdered() {
query := drops[i]
if !d.noInfo {
fmt.Printf("drop table `%s`\n", mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for i, mi := range modelCache.allOrdered() {
if tables[mi.table] {
if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table)
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; !ok {
fields = append(fields, fi)
}
}
for _, fi := range fields {
query := getColumnAddQuery(d.al, fi)
if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
for _, idx := range indexes[mi.table] {
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
continue
}
if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table)
}
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.SQL)
}
for _, query := range queries {
_, err := db.Exec(query)
if d.verbose {
query = " " + strings.Join(strings.Split(query, "\n"), "\n ")
fmt.Println(query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
if d.verbose {
fmt.Println("")
}
}
return nil
}
// database creation commander interface implement.
type commandSQLAll struct {
al *alias
}
// parse orm command line arguments.
func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSQLAll) Run() error {
sqls, indexes := getDbCreateSQL(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
}
fmt.Println(strings.Join(all, "\n\n"))
return nil
}
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSQLAll)
}
// RunSyncdb run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()
al := getDbAlias(name)
cmd := new(commandSyncDb)
cmd.al = al
cmd.force = force
cmd.noInfo = !verbose
cmd.verbose = verbose
cmd.rtOnError = true
return cmd.Run()
}

320
pkg/orm/cmd_utils.go Normal file
View File

@ -0,0 +1,320 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"os"
"strings"
)
type dbIndex struct {
Table string
Name string
SQL string
}
// create database drop sql.
func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
}
Q := al.DbBaser.TableQuote()
for _, mi := range modelCache.allOrdered() {
sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
}
return sqls
}
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
fieldSize := fi.size
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize)
}
case TypeCharField:
col = fmt.Sprintf(T["string-char"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeTimeField:
col = T["time.Time-clock"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if !strings.Contains(s, "%d") {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case TypeJSONField:
if al.Driver != DRPostgres {
fieldType = TypeVarCharField
goto checkColumn
}
col = T["json"]
case TypeJsonbField:
if al.Driver != DRPostgres {
fieldType = TypeVarCharField
goto checkColumn
}
col = T["jsonb"]
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn
}
return
}
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
if !fi.null {
typ += " " + "NOT NULL"
}
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
Q, fi.mi.table, Q,
Q, fi.column, Q,
typ, getColumnDefault(fi),
)
}
// create database creation string.
func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
}
Q := al.DbBaser.TableQuote()
T := al.DbBaser.DbTypes()
sep := fmt.Sprintf("%s, %s", Q, Q)
tableIndexes = make(map[string][]dbIndex)
for _, mi := range modelCache.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
columns := make([]string, 0, len(mi.fields.fieldsDB))
sqlIndexes := [][]string{}
for _, fi := range mi.fields.fieldsDB {
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
col := getColumnTyp(al, fi)
if fi.auto {
switch al.Driver {
case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
}
} else if fi.pk {
column += col + " " + T["pk"]
} else {
column += col
if !fi.null {
column += " " + "NOT NULL"
}
//if fi.initial.String() != "" {
// column += " DEFAULT " + fi.initial.String()
//}
// Append attribute DEFAULT
column += getColumnDefault(fi)
if fi.unique {
column += " " + "UNIQUE"
}
if fi.index {
sqlIndexes = append(sqlIndexes, []string{fi.column})
}
}
if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1)
}
if fi.description != "" && al.Driver!=DRSqlite {
column += " " + fmt.Sprintf("COMMENT '%s'",fi.description)
}
columns = append(columns, column)
}
if mi.model != nil {
allnames := getTableUnique(mi.addrField)
if !mi.manual && len(mi.uniques) > 0 {
allnames = append(allnames, mi.uniques)
}
for _, names := range allnames {
cols := make([]string, 0, len(names))
for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
cols = append(cols, fi.column)
} else {
panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
}
}
column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
columns = append(columns, column)
}
}
sql += strings.Join(columns, ",\n")
sql += "\n)"
if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
}
if engine == "" {
engine = al.Engine
}
sql += " ENGINE=" + engine
}
sql += ";"
sqls = append(sqls, sql)
if mi.model != nil {
for _, names := range getTableIndex(mi.addrField) {
cols := make([]string, 0, len(names))
for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
cols = append(cols, fi.column)
} else {
panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
}
}
sqlIndexes = append(sqlIndexes, cols)
}
}
for _, names := range sqlIndexes {
name := mi.table + "_" + strings.Join(names, "_")
cols := strings.Join(names, sep)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
index := dbIndex{}
index.Table = mi.table
index.Name = name
index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
}
return
}
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var (
v, t, d string
)
// Skip default attribute if field is in relations
if fi.rel || fi.reverse {
return v
}
t = " DEFAULT '%s' "
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v
case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField:
t = " DEFAULT %s "
d = "0"
case TypeBooleanField:
t = " DEFAULT %s "
d = "FALSE"
case TypeJSONField, TypeJsonbField:
d = "{}"
}
if fi.colDefault {
if !fi.initial.Exist() {
v = fmt.Sprintf(t, "")
} else {
v = fmt.Sprintf(t, fi.initial.String())
}
} else {
if !fi.null {
v = fmt.Sprintf(t, d)
}
}
return v
}

1902
pkg/orm/db.go Normal file

File diff suppressed because it is too large Load Diff

466
pkg/orm/db_alias.go Normal file
View File

@ -0,0 +1,466 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"fmt"
lru "github.com/hashicorp/golang-lru"
"reflect"
"sync"
"time"
)
// DriverType database driver constant int.
type DriverType int
// Enum the Database driver
const (
_ DriverType = iota // int enum type
DRMySQL // mysql
DRSqlite // sqlite
DROracle // oracle
DRPostgres // pgsql
DRTiDB // TiDB
)
// database driver string.
type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
// get name of current driver
func (d driver) Name() string {
return string(d)
}
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
"tidb": DRTiDB,
"oracle": DROracle,
"oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, //https://github.com/rana/ora
}
dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
DROracle: newdbBaseOracle(),
DRPostgres: newdbBasePostgres(),
DRTiDB: newdbBaseTidb(),
}
)
// database alias cacher.
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
if _, ok := ac.cache[name]; !ok {
ac.cache[name] = al
added = true
}
return
}
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
al, ok = ac.cache[name]
return
}
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
}
type DB struct {
*sync.RWMutex
DB *sql.DB
stmtDecorators *lru.Cache
}
func (d *DB) Begin() (*sql.Tx, error) {
return d.DB.Begin()
}
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return d.DB.BeginTx(ctx, opts)
}
//su must call release to release *sql.Stmt after using
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.RLock()
c, ok := d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.RUnlock()
return c.(*stmtDecorator), nil
}
d.RUnlock()
d.Lock()
c, ok = d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.Unlock()
return c.(*stmtDecorator), nil
}
stmt, err := d.Prepare(query)
if err != nil {
d.Unlock()
return nil, err
}
sd := newStmtDecorator(stmt)
sd.acquire()
d.stmtDecorators.Add(query, sd)
d.Unlock()
return sd, nil
}
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.DB.Prepare(query)
}
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return d.DB.PrepareContext(ctx, query)
}
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.Exec(args...)
}
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.ExecContext(ctx, args...)
}
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.Query(args...)
}
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryContext(ctx, args...)
}
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
sd, err := d.getStmtDecorator(query)
if err != nil {
panic(err)
}
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryRow(args...)
}
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
sd, err := d.getStmtDecorator(query)
if err != nil {
panic(err)
}
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryRowContext(ctx, args)
}
type alias struct {
Name string
Driver DriverType
DriverName string
DataSource string
MaxIdleConns int
MaxOpenConns int
DB *DB
DbBaser dbBaser
TZ *time.Location
Engine string
}
func detectTZ(al *alias) {
// orm timezone system match database
// default use Local
al.TZ = DefaultTimeLoc
if al.DriverName == "sphinx" {
return
}
switch al.Driver {
case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
if len(tz) >= 8 {
if tz[0] != '-' {
tz = "+" + tz
}
t, err := time.Parse("-07:00:00", tz)
if err == nil {
if t.Location().String() != "" {
al.TZ = t.Location()
}
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
// get default engine from current database
row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
var engine string
var tx bool
row.Scan(&engine, &tx)
if engine != "" {
al.Engine = engine
} else {
al.Engine = "INNODB"
}
case DRSqlite, DROracle:
al.TZ = time.UTC
case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmtDecorators: newStmtDecoratorLruWithEvict(),
}
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
if !dataBaseCache.add(aliasName, al) {
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params {
switch i {
case 0:
SetMaxIdleConns(al.Name, v)
case 1:
SetMaxOpenConns(al.Name, v)
}
}
end:
if err != nil {
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
}
return err
}
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; !ok {
drivers[driverName] = typ
} else {
if t != typ {
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
}
}
return nil
}
// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
}
return nil
}
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.DB.SetMaxIdleConns(maxIdleConns)
}
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns
al.DB.DB.SetMaxOpenConns(maxOpenConns)
// for tip go 1.2
if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
}
}
// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
if len(aliasNames) > 0 {
name = aliasNames[0]
} else {
name = "default"
}
al, ok := dataBaseCache.get(name)
if ok {
return al.DB.DB, nil
}
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
}
type stmtDecorator struct {
wg sync.WaitGroup
stmt *sql.Stmt
}
func (s *stmtDecorator) getStmt() *sql.Stmt {
return s.stmt
}
// acquire will add one
// since this method will be used inside read lock scope,
// so we can not do more things here
// we should think about refactor this
func (s *stmtDecorator) acquire() {
s.wg.Add(1)
}
func (s *stmtDecorator) release() {
s.wg.Done()
}
//garbage recycle for stmt
func (s *stmtDecorator) destroy() {
go func() {
s.wg.Wait()
_ = s.stmt.Close()
}()
}
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
return &stmtDecorator{
stmt: sqlStmt,
}
}
func newStmtDecoratorLruWithEvict() *lru.Cache {
cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) {
value.(*stmtDecorator).destroy()
})
return cache
}

183
pkg/orm/db_mysql.go Normal file
View File

@ -0,0 +1,183 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"reflect"
"strings"
)
// mysql operators.
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
// mysql column field types.
var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "longtext",
"time.Time-date": "date",
"time.Time": "datetime",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
}
// mysql dbBaser implementation.
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// InsertOrUpdate a row
// If your primary key or unique column conflict will update
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
argsMap := map[string]string{}
iouStr = "ON DUPLICATE KEY UPDATE"
//Get on the key-value pairs
for _, v := range args {
kv := strings.Split(v, "=")
if len(kv) == 2 {
argsMap[strings.ToLower(kv[0])] = kv[1]
}
}
isMulti := false
names := make([]string, 0, len(mi.fields.dbcols)-1)
Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
updateValues := make([]interface{}, 0)
updates := make([]string, len(names))
for i, v := range names {
marks[i] = "?"
valueStr := argsMap[strings.ToLower(v)]
if valueStr != "" {
updates[i] = "`" + v + "`" + "=" + valueStr
} else {
updates[i] = "`" + v + "`" + "=?"
updateValues = append(updateValues, values[i])
}
}
values = append(values, updateValues...)
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
//conflitValue maybe is a int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
}
return 0, err
}
row := q.QueryRow(query, values...)
var id int64
err = row.Scan(&id)
return id, err
}
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
return b
}

137
pkg/orm/db_oracle.go Normal file
View File

@ -0,0 +1,137 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strings"
)
// oracle operators.
var oracleOperators = map[string]string{
"exact": "= ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"//iendswith": "LIKE ?",
}
// oracle column field types.
var oracleTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "VARCHAR2(%d)",
"string-char": "CHAR(%d)",
"string-text": "VARCHAR2(%d)",
"time.Time-date": "DATE",
"time.Time": "TIMESTAMP",
"int8": "INTEGER",
"int16": "INTEGER",
"int32": "INTEGER",
"int64": "INTEGER",
"uint8": "INTEGER",
"uint16": "INTEGER",
"uint32": "INTEGER",
"uint64": "INTEGER",
"float64": "NUMBER",
"float64-decimal": "NUMBER(%d, %d)",
}
// oracle dbBaser
type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
return b
}
// OperatorSQL get oracle operator.
func (d *dbBaseOracle) OperatorSQL(operator string) string {
return oracleOperators[operator]
}
// DbTypes get oracle table field types.
func (d *dbBaseOracle) DbTypes() map[string]string {
return oracleTypes
}
//ShowTablesQuery show all the tables in database
func (d *dbBaseOracle) ShowTablesQuery() string {
return "SELECT TABLE_NAME FROM USER_TABLES"
}
// Oracle
func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+
"WHERE TABLE_NAME ='%s'", strings.ToUpper(table))
}
// check index is exist
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
for i := range marks {
marks[i] = ":" + names[i]
}
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
}
return 0, err
}
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
}

189
pkg/orm/db_postgres.go Normal file
View File

@ -0,0 +1,189 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
)
// postgresql operators.
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
"contains": "LIKE ?",
"icontains": "LIKE UPPER(?)",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
"iendswith": "LIKE UPPER(?)",
}
// postgresql column field types.
var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "timestamp with time zone",
"int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`,
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`,
"uint16": `integer CHECK("%COL%" >= 0)`,
"uint32": `bigint CHECK("%COL%" >= 0)`,
"uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
"json": "json",
"jsonb": "jsonb",
}
// postgresql dbBaser.
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)
case "iexact", "icontains", "istartswith", "iendswith":
*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
}
}
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
for _, c := range q {
if c == '?' {
num++
}
}
if num == 0 {
return
}
data := make([]byte, 0, len(q)+num)
num = 1
for i := 0; i < len(q); i++ {
c := q[i]
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num++
} else {
data = append(data, c)
}
}
*query = string(data)
}
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
fi := mi.fields.pk
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
return false
}
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
}
return true
}
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
return err
}
}
return nil
}
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
return b
}

161
pkg/orm/db_sqlite.go Normal file
View File

@ -0,0 +1,161 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"database/sql"
"fmt"
"reflect"
"time"
)
// sqlite operators.
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
"contains": "LIKE ? ESCAPE '\\'",
"icontains": "LIKE ? ESCAPE '\\'",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
"iendswith": "LIKE ? ESCAPE '\\'",
}
// sqlite column types.
var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "character(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "datetime",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "real",
"float64-decimal": "decimal",
}
// sqlite dbBaser.
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(q, mi, ind, tz, cols, false)
}
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
if err != nil {
return nil, err
}
columns := make(map[string][3]string)
for rows.Next() {
var tmp, name, typ, null sql.NullString
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
if err != nil {
return nil, err
}
columns[name.String] = [3]string{name.String, typ.String, null.String}
}
return columns, nil
}
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
if err != nil {
panic(err)
}
defer rows.Close()
for rows.Next() {
var tmp, index sql.NullString
rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
if name == index.String {
return true
}
}
return false
}
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
return b
}

482
pkg/orm/db_tables.go Normal file
View File

@ -0,0 +1,482 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strings"
"time"
)
// table info struct.
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
// tables collection struct, contains some tables.
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
skipEnd bool
}
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cancel = true
jtl *dbTable
)
inner := true
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
if fi.null || t.skipEnd {
inner = false
}
jt := t.set(names, mmi, fi, inner)
jt.jtl = jtl
if fi.reverse {
cancel = false
}
if cancel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Errorf("unknown model/table name `%s`", ex))
}
}
}
}
// generate join string.
func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse {
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
t2, Q, c2, Q, t1, Q, c1, Q)
}
return
}
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi
)
num := len(exprs) - 1
var names []string
inner := true
loopFor:
for i, ex := range exprs {
var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok {
isRel := fi.rel || fi.reverse
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
}
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (!fi.mi.isThrough || num != i) {
if fi.null || t.skipEnd {
inner = false
}
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
}
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
}
break loopFor
} else {
index = ""
name = ""
info = nil
success = false
return
}
}
success = index != "" && info != nil
return
}
// generate condition sql.
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
Q := t.base.TableQuote()
mi := t.mi
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, _, fi, suc := t.parseExprs(mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
var operSQL string
var args []interface{}
if p.isRaw {
operSQL = p.sql
} else {
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
}
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
}
if !sub && where != "" {
where = "WHERE " + where
}
return
}
// generate group sql.
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 {
return
}
Q := t.base.TableQuote()
groupSqls := make([]string, 0, len(groups))
for _, group := range groups {
exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
}
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return
}
// generate order sql.
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 {
return
}
Q := t.base.TableQuote()
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
if limit < 0 {
// no limit
if offset > 0 {
maxLimit := t.base.MaxLimit()
if maxLimit == 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
}
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}

63
pkg/orm/db_tidb.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
)
// mysql dbBaser implementation.
type dbBaseTidb struct {
dbBase
}
var _ dbBaser = new(dbBaseTidb)
// get mysql operator.
func (d *dbBaseTidb) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseTidb) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseTidb() dbBaser {
b := new(dbBaseTidb)
b.ins = b
return b
}

177
pkg/orm/db_utils.go Normal file
View File

@ -0,0 +1,177 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"reflect"
"time"
)
// get table alias.
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
}
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPositiveIntegerField > 0 {
vu := v.Uint()
exist = vu > 0
value = vu
} else if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = true
value = vu
} else if fi.fieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
for _, arg := range args {
val := reflect.ValueOf(arg)
if arg == nil {
params = append(params, arg)
continue
}
kind := val.Kind()
if kind == reflect.Ptr {
val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil {
if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else if len(v) >= 10 {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(formatDate, s, tz)
} else {
s := v
if len(s) > 8 {
s = v[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate)
} else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime)
} else {
v = t.In(tz).Format(formatTime)
}
}
}
}
arg = v
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
arg = val.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
arg = val.Uint()
case reflect.Float32:
arg, _ = StrTo(ToStr(arg)).Float64()
case reflect.Float64:
arg = val.Float()
case reflect.Bool:
arg = val.Bool()
case reflect.Slice, reflect.Array:
if _, ok := arg.([]byte); ok {
continue outFor
}
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate)
} else if fi != nil && fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime)
} else if fi != nil && fi.fieldType == TypeTimeField {
arg = v.In(tz).Format(formatTime)
} else {
arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
}
}
params = append(params, arg)
}
return
}

99
pkg/orm/models.go Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"sync"
)
const (
odCascade = "cascade"
odSetNULL = "set_null"
odSetDefault = "set_default"
odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
)
var (
modelCache = &_modelCache{
cache: make(map[string]*modelInfo),
cacheByFullName: make(map[string]*modelInfo),
}
)
// model info collection
type _modelCache struct {
sync.RWMutex // only used outsite for bootStrap
orders []string
cache map[string]*modelInfo
cacheByFullName map[string]*modelInfo
done bool
}
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache {
m[k] = v
}
return m
}
// get ordered model info
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders {
m = append(m, mc.cache[table])
}
return m
}
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
// get model info by full name
func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFullName[name]
return
}
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
mc.cacheByFullName[mi.fullName] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
return mii
}
// clean all model info.
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
mc.cacheByFullName = make(map[string]*modelInfo)
mc.done = false
}
// ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()
}

347
pkg/orm/models_boot.go Normal file
View File

@ -0,0 +1,347 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"os"
"reflect"
"runtime/debug"
"strings"
)
// register models.
// PrefixOrSuffix means table name prefix or suffix.
// isPrefix whether the prefix is prefix or suffix
func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
val := reflect.ValueOf(model)
typ := reflect.Indirect(val).Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
// For this case:
// u := &User{}
// registerModel(&u)
if typ.Kind() == reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
}
table := getTableName(val)
if PrefixOrSuffix != "" {
if isPrefix {
table = PrefixOrSuffix + table
} else {
table = table + PrefixOrSuffix
}
}
// models's fullname is pkgpath + struct name
name := getFullName(typ)
if _, ok := modelCache.getByFullName(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
os.Exit(2)
}
if _, ok := modelCache.get(table); ok {
fmt.Printf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
os.Exit(2)
}
mi := newModelInfo(val)
if mi.fields.pk == nil {
outFor:
for _, fi := range mi.fields.fieldsDB {
if strings.ToLower(fi.name) == "id" {
switch fi.addrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true
fi.pk = true
mi.fields.pk = fi
break outFor
}
}
}
if mi.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
os.Exit(2)
}
}
mi.table = table
mi.pkg = typ.PkgPath()
mi.model = model
mi.manual = true
modelCache.set(table, mi)
}
// bootstrap models
func bootStrap() {
if modelCache.done {
return
}
var (
err error
models map[string]*modelInfo
)
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register DataBase alias named `default`")
goto end
}
// set rel and reverse model
// RelManyToMany set the relTable
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.columns {
if fi.rel || fi.reverse {
elm := fi.addrValue.Type().Elem()
if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
elm = elm.Elem()
}
// check the rel or reverse model already register
name := getFullName(elm)
mii, ok := modelCache.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
}
fi.relModelInfo = mii
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
rmi, ok := modelCache.getByFullName(fi.relThrough)
if !ok || pn != rmi.pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
goto end
}
fi.relThroughModelInfo = rmi
fi.relTable = rmi.table
} else {
err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
goto end
}
} else {
i := newM2MModelInfo(mi, mii)
if fi.relTable != "" {
i.table = fi.relTable
}
if v := modelCache.set(i.table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
goto end
}
fi.relTable = i.table
fi.relThroughModelInfo = i
}
fi.relThroughModelInfo.isThrough = true
}
}
}
}
// check the rel filed while the relModelInfo also has filed point to current model
// if not exist, add a new field to the relModelInfo
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
inModel := false
for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
if ffi.relModelInfo == mi {
inModel = true
break
}
}
if !inModel {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
ffi.name = mi.name
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
ffi.reverse = true
ffi.relModelInfo = mi
ffi.mi = rmi
if fi.fieldType == RelOneToOne {
ffi.fieldType = RelReverseOne
} else {
ffi.fieldType = RelReverseMany
}
if !rmi.fields.Add(ffi) {
added := false
for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
if added = rmi.fields.Add(ffi); added {
break
}
}
if !added {
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
}
}
}
}
}
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
case RelManyToMany:
for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
switch ffi.fieldType {
case RelOneToOne, RelForeignKey:
if ffi.relModelInfo == fi.relModelInfo {
fi.reverseFieldInfoTwo = ffi
}
if ffi.relModelInfo == mi {
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
}
}
}
if fi.reverseFieldInfoTwo == nil {
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
fi.relThroughModelInfo.fullName)
goto end
}
}
}
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsReverse {
switch fi.fieldType {
case RelReverseOne:
found := false
mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
break mForA
}
}
if !found {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
case RelReverseMany:
found := false
mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
break mForB
}
}
if !found {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
fi.relTable != "" && fi.relTable == ffi.relTable ||
fi.relThrough == "" && fi.relTable == ""
if ffi.relModelInfo == mi && conditions {
found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name
fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
fi.relThroughModelInfo = ffi.relThroughModelInfo
fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
fi.reverseFieldInfoM2M = ffi
ffi.reverseFieldInfoM2M = fi
break mForC
}
}
}
if !found {
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
}
end:
if err != nil {
fmt.Println(err)
debug.PrintStack()
os.Exit(2)
}
}
// RegisterModel register models
func RegisterModel(models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
}
RegisterModelWithPrefix("", models...)
}
// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap"))
}
for _, model := range models {
registerModel(prefix, model, true)
}
}
// RegisterModelWithSuffix register models with a suffix
func RegisterModelWithSuffix(suffix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap"))
}
for _, model := range models {
registerModel(suffix, model, false)
}
}
// BootStrap bootstrap models.
// make all model parsed and can not add more models
func BootStrap() {
modelCache.Lock()
defer modelCache.Unlock()
if modelCache.done {
return
}
bootStrap()
modelCache.done = true
}

783
pkg/orm/models_fields.go Normal file
View File

@ -0,0 +1,783 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
"time"
)
// Define the Type enum
const (
TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField
TypeTextField
TypeTimeField
TypeDateField
TypeDateTimeField
TypeBitField
TypeSmallIntegerField
TypeIntegerField
TypeBigIntegerField
TypePositiveBitField
TypePositiveSmallIntegerField
TypePositiveIntegerField
TypePositiveBigIntegerField
TypeFloatField
TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1
)
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
// CharField A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeVarCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder
var _ Fielder = new(CharField)
// TimeField A time, represented in go by a time.Time instance.
// only time values like 10:00:00
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField)
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datetime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDate)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datetime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatDateTime)
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datetime implement fielder
var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err == nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder
var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)

473
pkg/orm/models_info_f.go Normal file
View File

@ -0,0 +1,473 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"errors"
"fmt"
"reflect"
"strings"
)
var errSkipField = errors.New("skip field")
// field info collection
type fields struct {
pk *fieldInfo
columns map[string]*fieldInfo
fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo
fieldsByType map[int][]*fieldInfo
fieldsRel []*fieldInfo
fieldsReverse []*fieldInfo
fieldsDB []*fieldInfo
rels []*fieldInfo
orders []string
dbcols []string
}
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
f.fields[fi.name] = fi
f.fieldsLow[strings.ToLower(fi.name)] = fi
} else {
return
}
if _, ok := f.fieldsByType[fi.fieldType]; !ok {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
}
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
f.orders = append(f.orders, fi.column)
if fi.dbcol {
f.dbcols = append(f.dbcols, fi.column)
f.fieldsDB = append(f.fieldsDB, fi)
}
if fi.rel {
f.fieldsRel = append(f.fieldsRel, fi)
}
if fi.reverse {
f.fieldsReverse = append(f.fieldsReverse, fi)
}
return true
}
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
}
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
}
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
return fi, ok
}
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
return fi, ok
}
if fi, ok := f.columns[name]; ok {
return fi, ok
}
return nil, false
}
// create new field info collection
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
f.fieldsLow = make(map[string]*fieldInfo)
f.columns = make(map[string]*fieldInfo)
f.fieldsByType = make(map[int][]*fieldInfo)
return f
}
// single field info
type fieldInfo struct {
mi *modelInfo
fieldIndex []int
fieldType int
dbcol bool // table column fk and onetoone
inModel bool
name string
fullName string
column string
addrValue reflect.Value
sf reflect.StructField
auto bool
pk bool
null bool
index bool
unique bool
colDefault bool // whether has default tag
initial StrTo // store the default value
size int
toText bool
autoNow bool
autoNowAdd bool
rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
reverse bool
reverseField string
reverseFieldInfo *fieldInfo
reverseFieldInfoTwo *fieldInfo
reverseFieldInfoM2M *fieldInfo
relTable string
relThrough string
relThroughModelInfo *modelInfo
relModelInfo *modelInfo
digits int
decimals int
isFielder bool // implement Fielder interface
onDelete string
description string
}
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) {
var (
tag string
tagValue string
initial StrTo // store the default value
fieldType int
attrs map[string]bool
tags map[string]string
addrField reflect.Value
)
fi = new(fieldInfo)
// if field which CanAddr is the follow type
// A value is addressable if it is an element of a slice,
// an element of an addressable array, a field of an
// addressable struct, or the result of dereferencing a pointer.
addrField = field
if field.CanAddr() && field.Kind() != reflect.Ptr {
addrField = field.Addr()
if _, ok := addrField.Interface().(Fielder); !ok {
if field.Kind() == reflect.Slice {
addrField = field
}
}
}
attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
if _, ok := attrs["-"]; ok {
return nil, errSkipField
}
digits := tags["digits"]
decimals := tags["decimals"]
size := tags["size"]
onDelete := tags["on_delete"]
initial.Clear()
if v, ok := tags["default"]; ok {
initial.Set(v)
}
checkType:
switch f := addrField.Interface().(type) {
case Fielder:
fi.isFielder = true
if field.Kind() == reflect.Ptr {
err = fmt.Errorf("the model Fielder can not be use ptr")
goto end
}
fieldType = f.FieldType()
if fieldType&IsRelField > 0 {
err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42")
goto end
}
default:
tag = "rel"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "fk":
fieldType = RelForeignKey
break checkType
case "one":
fieldType = RelOneToOne
break checkType
case "m2m":
fieldType = RelManyToMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType
default:
err = fmt.Errorf("rel only allow these value: fk, one, m2m")
goto wrongTag
}
}
tag = "reverse"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "one":
fieldType = RelReverseOne
break checkType
case "many":
fieldType = RelReverseMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType
default:
err = fmt.Errorf("reverse only allow these value: one, many")
goto wrongTag
}
}
fieldType, err = getFieldType(addrField)
if err != nil {
goto end
}
if fieldType == TypeVarCharField {
switch tags["type"] {
case "char":
fieldType = TypeCharField
case "text":
fieldType = TypeTextField
case "json":
fieldType = TypeJSONField
case "jsonb":
fieldType = TypeJsonbField
}
}
if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField
}
if fieldType == TypeDateTimeField && tags["type"] == "date" {
fieldType = TypeDateField
}
if fieldType == TypeTimeField && tags["type"] == "time" {
fieldType = TypeTimeField
}
}
// check the rel and reverse type
// rel should Ptr
// reverse should slice []*struct
switch fieldType {
case RelForeignKey, RelOneToOne, RelReverseOne:
if field.Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
goto end
}
case RelManyToMany, RelReverseMany:
if field.Kind() != reflect.Slice {
err = fmt.Errorf("rel/reverse:many field must be slice")
goto end
} else {
if field.Type().Elem().Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
goto end
}
}
}
if fieldType&IsFieldType == 0 {
err = fmt.Errorf("wrong field type")
goto end
}
fi.fieldType = fieldType
fi.name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = addrField
fi.sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name
fi.description = tags["description"]
fi.null = attrs["null"]
fi.index = attrs["index"]
fi.auto = attrs["auto"]
fi.pk = attrs["pk"]
fi.unique = attrs["unique"]
// Mark object property if there is attribute "default" in the orm configuration
if _, ok := tags["default"]; ok {
fi.colDefault = true
}
switch fieldType {
case RelManyToMany, RelReverseMany, RelReverseOne:
fi.null = false
fi.index = false
fi.auto = false
fi.pk = false
fi.unique = false
default:
fi.dbcol = true
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
fi.rel = true
if fieldType == RelOneToOne {
fi.unique = true
}
case RelReverseMany, RelReverseOne:
fi.reverse = true
}
if fi.rel && fi.dbcol {
switch onDelete {
case odCascade, odDoNothing:
case odSetDefault:
if !initial.Exist() {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case odSetNULL:
if !fi.null {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = odCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
}
}
fi.onDelete = onDelete
}
switch fieldType {
case TypeBooleanField:
case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" {
v, e := StrTo(size).Int32()
if e != nil {
err = fmt.Errorf("wrong size value `%s`", size)
} else {
fi.size = int(v)
}
} else {
fi.size = 255
fi.toText = true
}
case TypeTextField:
fi.index = false
fi.unique = false
case TypeTimeField, TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.autoNow = true
} else if attrs["auto_now_add"] {
fi.autoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:
d1 := digits
d2 := decimals
v1, er1 := StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int8()
if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end
}
fi.digits = int(v1)
fi.decimals = int(v2)
default:
switch {
case fieldType&IsIntegerField > 0:
case fieldType&IsRelField > 0:
}
}
if fieldType&IsIntegerField == 0 {
if fi.auto {
err = fmt.Errorf("non-integer type cannot set auto")
goto end
}
}
if fi.auto || fi.pk {
if fi.auto {
switch addrField.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
default:
err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind())
goto end
}
fi.pk = true
}
fi.null = false
fi.index = false
fi.unique = false
}
if fi.unique {
fi.index = false
}
// can not set default for these type
if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
initial.Clear()
}
if initial.Exist() {
v := initial
switch fieldType {
case TypeBooleanField:
_, err = v.Bool()
case TypeFloatField, TypeDecimalField:
_, err = v.Float64()
case TypeBitField:
_, err = v.Int8()
case TypeSmallIntegerField:
_, err = v.Int16()
case TypeIntegerField:
_, err = v.Int32()
case TypeBigIntegerField:
_, err = v.Int64()
case TypePositiveBitField:
_, err = v.Uint8()
case TypePositiveSmallIntegerField:
_, err = v.Uint16()
case TypePositiveIntegerField:
_, err = v.Uint32()
case TypePositiveBigIntegerField:
_, err = v.Uint64()
}
if err != nil {
tag, tagValue = "default", tags["default"]
goto wrongTag
}
}
fi.initial = initial
end:
if err != nil {
return nil, err
}
return
wrongTag:
return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err)
}

148
pkg/orm/models_info_m.go Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"os"
"reflect"
)
// single model info
type modelInfo struct {
pkg string
name string
fullName string
table string
model interface{}
fields *fields
manual bool
addrField reflect.Value //store the original struct value
uniques []string
isThrough bool
}
// new model info
func newModelInfo(val reflect.Value) (mi *modelInfo) {
mi = &modelInfo{}
mi.fields = newFields()
ind := reflect.Indirect(val)
mi.addrField = val
mi.name = ind.Type().Name()
mi.fullName = getFullName(ind.Type())
addModelFields(mi, ind, "", []int{})
return
}
// index: FieldByIndex returns the nested field corresponding to index
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct fields
if sf.Anonymous {
addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = newFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
//record current field index
fi.fieldIndex = append(fi.fieldIndex, index...)
fi.fieldIndex = append(fi.fieldIndex, i)
fi.mi = mi
fi.inModel = true
if !mi.fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
mi.fields.pk = fi
}
}
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
}
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
mi = new(modelInfo)
mi.fields = newFields()
mi.table = m1.table + "_" + m2.table + "s"
mi.name = camelString(mi.table)
mi.fullName = m1.pkg + "." + mi.name
fa := new(fieldInfo) // pk
f1 := new(fieldInfo) // m1 table RelForeignKey
f2 := new(fieldInfo) // m2 table RelForeignKey
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
f1.fieldType = RelForeignKey
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = mi.fullName + "." + f1.name
f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
f2.rel = true
f1.relTable = m1.table
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = mi
f2.mi = mi
mi.fields.Add(fa)
mi.fields.Add(f1)
mi.fields.Add(f2)
mi.fields.pk = fa
mi.uniques = []string{f1.column, f2.column}
return
}

497
pkg/orm/models_test.go Normal file
View File

@ -0,0 +1,497 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb"
)
// A slice string field.
type SliceStringField []string
func (e SliceStringField) Value() []string {
return []string(e)
}
func (e *SliceStringField) Set(d []string) {
*e = SliceStringField(d)
}
func (e *SliceStringField) Add(v string) {
*e = append(*e, v)
}
func (e *SliceStringField) String() string {
return strings.Join(e.Value(), ",")
}
func (e *SliceStringField) FieldType() int {
return TypeVarCharField
}
func (e *SliceStringField) SetRaw(value interface{}) error {
switch d := value.(type) {
case []string:
e.Set(d)
case string:
if len(d) > 0 {
parts := strings.Split(d, ",")
v := make([]string, 0, len(parts))
for _, p := range parts {
v = append(v, strings.TrimSpace(p))
}
e.Set(v)
}
default:
return fmt.Errorf("<SliceStringField.SetRaw> unknown value `%v`", value)
}
return nil
}
func (e *SliceStringField) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(SliceStringField)
// A json field.
type JSONFieldTest struct {
Name string
Data string
}
func (e *JSONFieldTest) String() string {
data, _ := json.Marshal(e)
return string(data)
}
func (e *JSONFieldTest) FieldType() int {
return TypeTextField
}
func (e *JSONFieldTest) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%v`", value)
}
}
func (e *JSONFieldTest) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(JSONFieldTest)
type Data struct {
ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
JSON string `orm:"type(json);default({\"name\":\"json\"})"`
Jsonb string `orm:"type(jsonb)"`
Time time.Time `orm:"type(time)"`
Date time.Time `orm:"type(date)"`
DateTime time.Time `orm:"column(datetime)"`
Byte byte
Rune rune
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
Float32 float32
Float64 float64
Decimal float64 `orm:"digits(8);decimals(4)"`
}
type DataNull struct {
ID int `orm:"column(id)"`
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
JSON string `orm:"type(json);null"`
Jsonb string `orm:"type(jsonb);null"`
Time time.Time `orm:"null;type(time)"`
Date time.Time `orm:"null;type(date)"`
DateTime time.Time `orm:"null;column(datetime)"`
Byte byte `orm:"null"`
Rune rune `orm:"null"`
Int int `orm:"null"`
Int8 int8 `orm:"null"`
Int16 int16 `orm:"null"`
Int32 int32 `orm:"null"`
Int64 int64 `orm:"null"`
Uint uint `orm:"null"`
Uint8 uint8 `orm:"null"`
Uint16 uint16 `orm:"null"`
Uint32 uint32 `orm:"null"`
Uint64 uint64 `orm:"null"`
Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"`
NullString sql.NullString `orm:"null"`
NullBool sql.NullBool `orm:"null"`
NullFloat64 sql.NullFloat64 `orm:"null"`
NullInt64 sql.NullInt64 `orm:"null"`
BooleanPtr *bool `orm:"null"`
CharPtr *string `orm:"null;size(50)"`
TextPtr *string `orm:"null;type(text)"`
BytePtr *byte `orm:"null"`
RunePtr *rune `orm:"null"`
IntPtr *int `orm:"null"`
Int8Ptr *int8 `orm:"null"`
Int16Ptr *int16 `orm:"null"`
Int32Ptr *int32 `orm:"null"`
Int64Ptr *int64 `orm:"null"`
UintPtr *uint `orm:"null"`
Uint8Ptr *uint8 `orm:"null"`
Uint16Ptr *uint16 `orm:"null"`
Uint32Ptr *uint32 `orm:"null"`
Uint64Ptr *uint64 `orm:"null"`
Float32Ptr *float32 `orm:"null"`
Float64Ptr *float64 `orm:"null"`
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
TimePtr *time.Time `orm:"null;type(time)"`
DatePtr *time.Time `orm:"null;type(date)"`
DateTimePtr *time.Time `orm:"null"`
}
type String string
type Boolean bool
type Byte byte
type Rune rune
type Int int
type Int8 int8
type Int16 int16
type Int32 int32
type Int64 int64
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Float32 float64
type Float64 float64
type DataCustom struct {
ID int `orm:"column(id)"`
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
Byte Byte
Rune Rune
Int Int
Int8 Int8
Int16 Int16
Int32 Int32
Int64 Int64
Uint Uint
Uint8 Uint8
Uint16 Uint16
Uint32 Uint32
Uint64 Uint64
Float32 Float32
Float64 Float64
Decimal Float64 `orm:"digits(8);decimals(4)"`
}
// only for mysql
type UserBig struct {
ID uint64 `orm:"column(id)"`
Name string
}
type User struct {
ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16 `orm:"column(Status)"`
IsStaff bool
IsActive bool `orm:"default(true)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
ShouldSkip string `orm:"-"`
Nums int
Langs SliceStringField `orm:"size(100)"`
Extra JSONFieldTest `orm:"type(text)"`
unexport bool `orm:"-"`
unexportBool bool
}
func (u *User) TableIndex() [][]string {
return [][]string{
{"Id", "UserName"},
{"Id", "Created"},
}
}
func (u *User) TableUnique() [][]string {
return [][]string{
{"UserName", "Email"},
}
}
func NewUser() *User {
obj := new(User)
return obj
}
type Profile struct {
ID int `orm:"column(id)"`
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
BestPost *Post `orm:"rel(one);null"`
}
func (u *Profile) TableName() string {
return "user_profile"
}
func NewProfile() *Profile {
obj := new(Profile)
return obj
}
type Post struct {
ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"`
Content string `orm:"type(text)"`
Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"`
}
func (u *Post) TableIndex() [][]string {
return [][]string{
{"Id", "Created"},
}
}
func NewPost() *Post {
obj := new(Post)
return obj
}
type Tag struct {
ID int `orm:"column(id)"`
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
}
func NewTag() *Tag {
obj := new(Tag)
return obj
}
type PostTags struct {
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"`
}
func (m *PostTags) TableName() string {
return "prefix_post_tags"
}
type Comment struct {
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"`
}
func NewComment() *Comment {
obj := new(Comment)
return obj
}
type Group struct {
ID int `orm:"column(gid);size(32)"`
Name string
Permissions []*Permission `orm:"reverse(many)" json:"-"`
}
type Permission struct {
ID int `orm:"column(id)"`
Name string
Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"`
}
type GroupPermissions struct {
ID int `orm:"column(id)"`
Group *Group `orm:"rel(fk)"`
Permission *Permission `orm:"rel(fk)"`
}
type ModelID struct {
ID int64
}
type ModelBase struct {
ModelID
Created time.Time `orm:"auto_now_add;type(datetime)"`
Updated time.Time `orm:"auto_now;type(datetime)"`
}
type InLine struct {
// Common Fields
ModelBase
// Other Fields
Name string `orm:"unique"`
Email string
}
func NewInLine() *InLine {
return new(InLine)
}
type InLineOneToOne struct {
// Common Fields
ModelBase
Note string
InLine *InLine `orm:"rel(fk);column(inline)"`
}
func NewInLineOneToOne() *InLineOneToOne {
return new(InLineOneToOne)
}
type IntegerPk struct {
ID int64 `orm:"pk"`
Value string
}
type UintPk struct {
ID uint32 `orm:"pk"`
Name string
}
type PtrPk struct {
ID *IntegerPk `orm:"pk;rel(one)"`
Positive bool
}
var DBARGS = struct {
Driver string
Source string
Debug string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
os.Getenv("ORM_DEBUG"),
}
var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
IsTidb = DBARGS.Driver == "tidb"
)
var (
dORM Ormer
dDbBaser dbBaser
)
var (
helpinfo = `need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/astaxie/beego/orm
#### Sqlite3
export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm
#### PostgreSQL
psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/astaxie/beego/orm
`
)
func init() {
Debug, _ = StrTo(DBARGS.Debug).Bool()
if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(helpinfo)
os.Exit(2)
}
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default")
if alias.Driver == DRMySQL {
alias.Engine = "INNODB"
}
}

227
pkg/orm/models_utils.go Normal file
View File

@ -0,0 +1,227 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
// 1 is attr
// 2 is tag
var supportTag = map[string]int{
"-": 1,
"null": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
"description": 2,
}
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
// getTableName get struct table name.
// If the struct implement the TableName, then get the result as tablename
// else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string {
if fun := val.MethodByName("TableName"); fun.IsValid() {
vals := fun.Call([]reflect.Value{})
// has return and the first val is string
if len(vals) > 0 && vals[0].Kind() == reflect.String {
return vals[0].String()
}
}
return snakeString(reflect.Indirect(val).Type().Name())
}
// get table engine, myisam or innodb.
func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 && vals[0].Kind() == reflect.String {
return vals[0].String()
}
}
return ""
}
// get table index from method.
func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 && vals[0].CanInterface() {
if d, ok := vals[0].Interface().([][]string); ok {
return d
}
}
}
return nil
}
// get table unique from method
func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 && vals[0].CanInterface() {
if d, ok := vals[0].Interface().([][]string); ok {
return d
}
}
}
return nil
}
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := col
if col == "" {
column = nameStrategyMap[nameStrategy](sf.Name)
}
switch ft {
case RelForeignKey, RelOneToOne:
if len(col) == 0 {
column = column + "_id"
}
case RelManyToMany, RelReverseMany, RelReverseOne:
column = sf.Name
}
return column
}
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) {
switch val.Type() {
case reflect.TypeOf(new(int8)):
ft = TypeBitField
case reflect.TypeOf(new(int16)):
ft = TypeSmallIntegerField
case reflect.TypeOf(new(int32)),
reflect.TypeOf(new(int)):
ft = TypeIntegerField
case reflect.TypeOf(new(int64)):
ft = TypeBigIntegerField
case reflect.TypeOf(new(uint8)):
ft = TypePositiveBitField
case reflect.TypeOf(new(uint16)):
ft = TypePositiveSmallIntegerField
case reflect.TypeOf(new(uint32)),
reflect.TypeOf(new(uint)):
ft = TypePositiveIntegerField
case reflect.TypeOf(new(uint64)):
ft = TypePositiveBigIntegerField
case reflect.TypeOf(new(float32)),
reflect.TypeOf(new(float64)):
ft = TypeFloatField
case reflect.TypeOf(new(bool)):
ft = TypeBooleanField
case reflect.TypeOf(new(string)):
ft = TypeVarCharField
case reflect.TypeOf(new(time.Time)):
ft = TypeDateTimeField
default:
elm := reflect.Indirect(val)
switch elm.Kind() {
case reflect.Int8:
ft = TypeBitField
case reflect.Int16:
ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int:
ft = TypeIntegerField
case reflect.Int64:
ft = TypeBigIntegerField
case reflect.Uint8:
ft = TypePositiveBitField
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField
case reflect.Uint64:
ft = TypePositiveBigIntegerField
case reflect.Float32, reflect.Float64:
ft = TypeFloatField
case reflect.Bool:
ft = TypeBooleanField
case reflect.String:
ft = TypeVarCharField
default:
if elm.Interface() == nil {
panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
}
switch elm.Interface().(type) {
case sql.NullInt64:
ft = TypeBigIntegerField
case sql.NullFloat64:
ft = TypeFloatField
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeVarCharField
case time.Time:
ft = TypeDateTimeField
}
}
}
if ft&IsFieldType == 0 {
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
}
return
}
// parse struct tag string
func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
attrs = make(map[string]bool)
tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) {
if v == "" {
continue
}
v = strings.TrimSpace(v)
if t := strings.ToLower(v); supportTag[t] == 1 {
attrs[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := t[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
tags[name] = v
}
} else {
DebugLog.Println("unsupport orm tag", v)
}
}
return
}

579
pkg/orm/orm.go Normal file
View File

@ -0,0 +1,579 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
// package main
//
// import (
// "fmt"
// "github.com/astaxie/beego/orm"
// _ "github.com/go-sql-driver/mysql" // import your used driver
// )
//
// // Model Struct
// type User struct {
// Id int `orm:"auto"`
// Name string `orm:"size(100)"`
// }
//
// func init() {
// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
// }
//
// func main() {
// o := orm.NewOrm()
// user := User{Name: "slene"}
// // insert
// id, err := o.Insert(&user)
// // update
// user.Name = "astaxie"
// num, err := o.Update(&user)
// // read one
// u := User{Id: user.Id}
// err = o.Read(&u)
// // delete
// num, err = o.Delete(&u)
// }
//
// more docs: http://beego.me/docs/mvc/model/overview.md
package orm
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"sync"
"time"
)
// DebugQueries define the debug
const (
DebugQueries = iota
)
// Define common vars
var (
Debug = false
DebugLog = NewLog(os.Stdout)
DefaultRowsLimit = -1
DefaultRelsDepth = 2
DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement")
)
// Params stores the Params
type Params map[string]interface{}
// ParamsList stores paramslist
type ParamsList []interface{}
type orm struct {
alias *alias
db dbQuerier
isTx bool
}
var _ Ormer = new(orm)
// get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md)
ind = reflect.Indirect(val)
typ := ind.Type()
if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind
}
panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
}
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name)
if !ok {
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
}
return fi
}
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
}
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
}
// Try to read a row from the database, or insert one if it doesn't exist
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
if err == ErrNoRows {
// Create
id, err := o.Insert(md)
return (err == nil), id, err
}
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint())
} else if mi.fields.pk.rel {
return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
} else {
id = vid.Int()
}
return false, id, err
}
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
o.setPk(mi, ind, id)
return id, nil
}
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt++
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
// InsertOrUpdate data to database
func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
if err != nil {
return id, err
}
o.setPk(mi, ind, id)
return id, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
}
// delete model in database
// cols shows the delete conditions values read from. default is pk
func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}
if num > 0 {
o.setPk(mi, ind, 0)
}
return num, nil
}
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
switch {
case fi.fieldType == RelManyToMany:
case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
default:
panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
}
return newQueryM2M(md, o, mi, fi, ind)
}
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name)
qs := qseter.(*querySet)
var relDepth int
var limit, offset int64
var order string
for i, arg := range args {
switch i {
case 0:
if v, ok := arg.(bool); ok {
if v {
relDepth = DefaultRelsDepth
}
} else if v, ok := arg.(int); ok {
relDepth = v
}
case 1:
limit = ToInt64(arg)
case 2:
offset = ToInt64(arg)
case 3:
order, _ = arg.(string)
}
}
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
limit = 1
offset = 0
}
qs.limit = limit
qs.offset = offset
qs.relDepth = relDepth
if len(order) > 0 {
qs.orders = []string{order}
}
find := ind.FieldByIndex(fi.fieldIndex)
var nums int64
var err error
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
val := reflect.New(find.Type().Elem())
container := val.Interface()
err = qs.One(container)
if err == nil {
find.Set(val)
nums = 1
}
default:
nums, err = qs.All(find.Addr().Interface())
}
return nums, err
}
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ?
_, _, _, qs := o.queryRelated(md, name)
return qs
}
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind)
if !exist {
panic(ErrMissPK)
}
var qs *querySet
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
if !fi.inModel {
break
}
qs = o.getRelQs(md, mi, fi)
case RelReverseOne, RelReverseMany:
if !fi.inModel {
break
}
qs = o.getReverseQs(md, mi, fi)
}
if qs == nil {
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
}
return mi, fi, ind, qs
}
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelReverseOne, RelReverseMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
}
var q *querySet
if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
q = newQuerySet(o, fi.relModelInfo).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
} else {
q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
}
return q
}
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
}
q := newQuerySet(o, fi.relModelInfo).(*querySet)
q.cond = NewCondition()
if fi.fieldType == RelManyToMany {
q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
} else {
q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
}
return q
}
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = nameStrategyMap[defaultNameStrategy](table)
if mi, ok := modelCache.get(name); ok {
qs = newQuerySet(o, mi)
}
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
if mi, ok := modelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi)
}
}
if qs == nil {
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
}
return
}
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error {
if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
}
if al, ok := dataBaseCache.get(name); ok {
o.alias = al
if Debug {
o.db = newDbQueryLog(al, al.DB)
} else {
o.db = al.DB
}
} else {
return fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", name)
}
return nil
}
// begin transaction
func (o *orm) Begin() error {
return o.BeginTx(context.Background(), nil)
}
func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
if o.isTx {
return ErrTxHasBegan
}
var tx *sql.Tx
tx, err := o.db.(txer).BeginTx(ctx, opts)
if err != nil {
return err
}
o.isTx = true
if Debug {
o.db.(*dbQueryLog).SetDB(tx)
} else {
o.db = tx
}
return nil
}
// commit transaction
func (o *orm) Commit() error {
if !o.isTx {
return ErrTxDone
}
err := o.db.(txEnder).Commit()
if err == nil {
o.isTx = false
o.Using(o.alias.Name)
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
// rollback transaction
func (o *orm) Rollback() error {
if !o.isTx {
return ErrTxDone
}
err := o.db.(txEnder).Rollback()
if err == nil {
o.isTx = false
o.Using(o.alias.Name)
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
// return current using database Driver
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// return sql.DBStats for current database
func (o *orm) DBStats() *sql.DBStats {
if o.alias != nil && o.alias.DB != nil {
stats := o.alias.DB.DB.Stats()
return &stats
}
return nil
}
// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
o := new(orm)
err := o.Using("default")
if err != nil {
panic(err)
}
return o
}
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
if dr, ok := drivers[driverName]; ok {
al = new(alias)
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
al.Name = aliasName
al.DriverName = driverName
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
stmtDecorators: newStmtDecoratorLruWithEvict(),
}
detectTZ(al)
o := new(orm)
o.alias = al
if Debug {
o.db = newDbQueryLog(o.alias, db)
} else {
o.db = db
}
return o, nil
}

153
pkg/orm/orm_conds.go Normal file
View File

@ -0,0 +1,153 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strings"
)
// ExprSep define the expression separation
const (
ExprSep = "__"
)
type condValue struct {
exprs []string
args []interface{}
cond *Condition
isOr bool
isNot bool
isCond bool
isRaw bool
sql string
}
// Condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
// NewCondition return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
// Raw add raw sql to condition
func (c Condition) Raw(expr string, sql string) *Condition {
if len(sql) == 0 {
panic(fmt.Errorf("<Condition.Raw> sql cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true})
return &c
}
// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return &c
}
// AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return &c
}
// AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic(fmt.Errorf("<Condition.AndCond> cannot use self as sub cond"))
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true})
}
return c
}
// AndNotCond combine a AND NOT condition to current condition
func (c *Condition) AndNotCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic(fmt.Errorf("<Condition.AndNotCond> cannot use self as sub cond"))
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true})
}
return c
}
// Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return &c
}
// OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return &c
}
// OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic(fmt.Errorf("<Condition.OrCond> cannot use self as sub cond"))
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true})
}
return c
}
// OrNotCond combine a OR NOT condition to current condition
func (c *Condition) OrNotCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic(fmt.Errorf("<Condition.OrNotCond> cannot use self as sub cond"))
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true})
}
return c
}
// IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
// clone clone a condition
func (c Condition) clone() *Condition {
return &c
}

222
pkg/orm/orm_log.go Normal file
View File

@ -0,0 +1,222 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"fmt"
"io"
"log"
"strings"
"time"
)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
//costomer log func
var LogFunc func(query map[string]interface{})
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d
}
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
var logMap = make(map[string]interface{})
sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0
logMap["cost_time"] = elsp
flag := " OK"
if err != nil {
flag = "FAIL"
}
logMap["flag"] = flag
con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))
}
if len(cons) > 0 {
con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `"))
}
if err != nil {
con += " - " + err.Error()
}
logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `"))
if LogFunc != nil{
LogFunc(logMap)
}
DebugLog.Println(con)
}
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct {
alias *alias
query string
stmt stmtQuerier
}
var _ stmtQuerier = new(stmtQueryLog)
func (d *stmtQueryLog) Close() error {
a := time.Now()
err := d.stmt.Close()
debugLogQueies(d.alias, "st.Close", d.query, a, err)
return err
}
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.stmt.Exec(args...)
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.stmt.Query(args...)
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
a := time.Now()
res := d.stmt.QueryRow(args...)
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
return res
}
func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
d := new(stmtQueryLog)
d.stmt = stmt
d.alias = alias
d.query = query
return d
}
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct {
alias *alias
db dbQuerier
tx txer
txe txEnder
}
var _ dbQuerier = new(dbQueryLog)
var _ txer = new(dbQueryLog)
var _ txEnder = new(dbQueryLog)
func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.Prepare(query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.PrepareContext(ctx, query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.Exec(query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.ExecContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.Query(query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.QueryContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRow(query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRowContext(ctx, query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).Begin()
debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).BeginTx(ctx, opts)
debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) Commit() error {
a := time.Now()
err := d.db.(txEnder).Commit()
debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err)
return err
}
func (d *dbQueryLog) Rollback() error {
a := time.Now()
err := d.db.(txEnder).Rollback()
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
return err
}
func (d *dbQueryLog) SetDB(db dbQuerier) {
d.db = db
}
func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier {
d := new(dbQueryLog)
d.alias = alias
d.db = db
return d
}

87
pkg/orm/orm_object.go Normal file
View File

@ -0,0 +1,87 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"reflect"
)
// an insert queryer struct
type insertSet struct {
mi *modelInfo
orm *orm
stmt stmtQuerier
closed bool
}
var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
name := getFullName(typ)
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
}
if name != o.mi.fullName {
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if o.mi.fields.pk.auto {
if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
return id, nil
}
// close insert queryer statement
func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
}
o.closed = true
return o.stmt.Close()
}
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
if err != nil {
return nil, err
}
if Debug {
bi.stmt = newStmtQueryLog(orm.alias, st, query)
} else {
bi.stmt = st
}
return bi, nil
}

140
pkg/orm/orm_querym2m.go Normal file
View File

@ -0,0 +1,140 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import "reflect"
// model to model struct
type queryM2M struct {
md interface{}
mi *modelInfo
fi *fieldInfo
qs *querySet
ind reflect.Value
}
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo
rfi := fi.reverseFieldInfoTwo
orm := o.qs.orm
dbase := orm.alias.DbBaser
var models []interface{}
var otherValues []interface{}
var otherNames []string
for _, colname := range mi.fields.dbcols {
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
mi.fields.columns[colname] != mi.fields.pk {
otherNames = append(otherNames, colname)
}
}
for i, md := range mds {
if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
otherValues = append(otherValues, md)
mds = append(mds[:i], mds[i+1:]...)
}
}
for _, md := range mds {
val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
if v.CanInterface() {
models = append(models, v.Interface())
}
}
} else {
models = append(models, md)
}
}
_, v1, exist := getExistPk(o.mi, o.ind)
if !exist {
panic(ErrMissPK)
}
names := []string{mfi.column, rfi.column}
values := make([]interface{}, 0, len(models)*2)
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
var v2 interface{}
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
} else {
_, v2, exist = getExistPk(fi.relModelInfo, ind)
if !exist {
panic(ErrMissPK)
}
}
values = append(values, v1, v2)
}
names = append(names, otherNames...)
values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values)
}
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
}
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
}
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
}
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
}
var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M)
qm2m.md = md
qm2m.mi = mi
qm2m.fi = fi
qm2m.ind = ind
qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet)
return qm2m
}

300
pkg/orm/orm_queryset.go Normal file
View File

@ -0,0 +1,300 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
)
type colValue struct {
value int64
opt operator
}
type operator int
// define Col operations
const (
ColAdd operator = iota
ColMinus
ColMultiply
ColExcept
ColBitAnd
ColBitRShift
ColBitLShift
ColBitXOR
ColBitOr
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift,
ColBitLShift, ColBitXOR, ColBitOr:
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}
v, err := StrTo(ToStr(value)).Int64()
if err != nil {
panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err))
}
var val colValue
val.value = v
val.opt = opt
return val
}
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
forupdate bool
orm *orm
ctx context.Context
forContext bool
}
var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.And(expr, args...)
return &o
}
// add raw sql to querySeter.
func (o querySet) FilterRaw(expr string, sql string) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.Raw(expr, sql)
return &o
}
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.AndNot(expr, args...)
return &o
}
// set offset number
func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num)
}
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit)
if len(args) > 0 {
o.setOffset(args[0])
}
return &o
}
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset)
return &o
}
// add GROUP expression
func (o querySet) GroupBy(exprs ...string) QuerySeter {
o.groups = exprs
return &o
}
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
return &o
}
// add DISTINCT to SELECT
func (o querySet) Distinct() QuerySeter {
o.distinct = true
return &o
}
// add FOR UPDATE to SELECT
func (o querySet) ForUpdate() QuerySeter {
o.forupdate = true
return &o
}
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
} else {
for _, p := range params {
switch val := p.(type) {
case string:
o.related = append(o.related, val)
case int:
o.relDepth = val
default:
panic(fmt.Errorf("<QuerySeter.RelatedSel> wrong param kind: %v", val))
}
}
}
return &o
}
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond
return &o
}
// get condition from QuerySeter
func (o querySet) GetCond() *Condition {
return o.cond
}
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
return err
}
if num == 0 {
return ErrNoRows
}
if num > 1 {
return ErrMultiRows
}
return nil
}
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)
o.mi = mi
o.orm = orm
return o
}

867
pkg/orm/orm_raw.go Normal file
View File

@ -0,0 +1,867 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"database/sql"
"fmt"
"reflect"
"time"
)
// raw sql string prepared statement
type rawPrepare struct {
rs *rawSet
stmt stmtQuerier
closed bool
}
func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
if o.closed {
return nil, ErrStmtClosed
}
return o.stmt.Exec(args...)
}
func (o *rawPrepare) Close() error {
o.closed = true
return o.stmt.Close()
}
func newRawPreparer(rs *rawSet) (RawPreparer, error) {
o := new(rawPrepare)
o.rs = rs
query := rs.query
rs.orm.alias.DbBaser.ReplaceMarks(&query)
st, err := rs.orm.db.Prepare(query)
if err != nil {
return nil, err
}
if Debug {
o.stmt = newStmtQueryLog(rs.orm.alias, st, query)
} else {
o.stmt = st
}
return o, nil
}
// raw query seter
type rawSet struct {
query string
args []interface{}
orm *orm
}
var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
}
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
return o.orm.db.Exec(query, args...)
}
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() {
case reflect.Bool:
if value == nil {
ind.SetBool(false)
} else if v, ok := value.(bool); ok {
ind.SetBool(v)
} else {
v, _ := StrTo(ToStr(value)).Bool()
ind.SetBool(v)
}
case reflect.String:
if value == nil {
ind.SetString("")
} else {
ind.SetString(ToStr(value))
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if value == nil {
ind.SetInt(0)
} else {
val := reflect.ValueOf(value)
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
ind.SetInt(val.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetInt(int64(val.Uint()))
default:
v, _ := StrTo(ToStr(value)).Int64()
ind.SetInt(v)
}
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if value == nil {
ind.SetUint(0)
} else {
val := reflect.ValueOf(value)
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
ind.SetUint(uint64(val.Int()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ind.SetUint(val.Uint())
default:
v, _ := StrTo(ToStr(value)).Uint64()
ind.SetUint(v)
}
}
case reflect.Float64, reflect.Float32:
if value == nil {
ind.SetFloat(0)
} else {
val := reflect.ValueOf(value)
switch val.Kind() {
case reflect.Float64:
ind.SetFloat(val.Float())
default:
v, _ := StrTo(ToStr(value)).Float64()
ind.SetFloat(v)
}
}
case reflect.Struct:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
return
}
switch ind.Interface().(type) {
case time.Time:
var str string
switch d := value.(type) {
case time.Time:
o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ)
ind.Set(reflect.ValueOf(d))
case []byte:
str = string(d)
case string:
str = d
}
if str != "" {
if len(str) >= 19 {
str = str[:19]
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
}
}
case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool:
indi := reflect.New(ind.Type()).Interface()
sc, ok := indi.(sql.Scanner)
if !ok {
return
}
err := sc.Scan(value)
if err == nil {
ind.Set(reflect.Indirect(reflect.ValueOf(sc)))
}
}
case reflect.Ptr:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
break
}
ind.Set(reflect.New(ind.Type().Elem()))
o.setFieldValue(reflect.Indirect(ind), value)
}
}
// set field value in loop for slice container
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr
cur := 0
for i := 0; i < len(sInds); i++ {
sInd := sInds[i]
eTyp := eTyps[i]
typ := eTyp
isPtr := false
if typ.Kind() == reflect.Ptr {
isPtr = true
typ = typ.Elem()
}
if typ.Kind() == reflect.Ptr {
isPtr = true
typ = typ.Elem()
}
var nInd reflect.Value
if init {
nInd = reflect.New(sInd.Type()).Elem()
} else {
nInd = nInds[i]
}
val := reflect.New(typ)
ind := val.Elem()
tpName := ind.Type().String()
if ind.Kind() == reflect.Struct {
if tpName == "time.Time" {
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil {
val = reflect.New(val.Type()).Elem()
} else {
o.setFieldValue(ind, value)
}
cur++
}
} else {
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil {
val = reflect.New(val.Type()).Elem()
} else {
o.setFieldValue(ind, value)
}
cur++
}
if nInd.Kind() == reflect.Slice {
if isPtr {
nInd = reflect.Append(nInd, val)
} else {
nInd = reflect.Append(nInd, ind)
}
} else {
if isPtr {
nInd.Set(val)
} else {
nInd.Set(ind)
}
}
nInds[i] = nInd
}
}
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<RawSeter.QueryRow> all args must be use ptr"))
}
etyp := ind.Type()
typ := etyp
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind)
eTyps = append(eTyps, etyp)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err
}
defer rows.Close()
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
}
o.setFieldValue(field, value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
} else {
return ErrNoRows
}
return nil
}
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice {
panic(fmt.Errorf("<RawSeter.QueryRows> all args must be use ptr slice"))
}
etyp := sInd.Type().Elem()
typ := etyp
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
rows, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
var cnt int64
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
field := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsRelField > 0 {
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
field.Set(mf)
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
}
o.setFieldValue(field, value)
}
}
} else {
// define recursive function
var recursiveSetField func(rv reflect.Value)
recursiveSetField = func(rv reflect.Value) {
for i := 0; i < rv.NumField(); i++ {
f := rv.Field(i)
fe := rv.Type().Field(i)
// check if the field is a Struct
// recursive the Struct type
if fe.Type.Kind() == reflect.Struct {
recursiveSetField(f)
}
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = nameStrategyMap[nameStrategy](fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
// init call the recursive function
recursiveSetField(ind)
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++
}
if cnt > 0 {
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
}
return cnt, nil
}
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
var (
maps []Params
lists []ParamsList
list ParamsList
)
typ := 0
switch container.(type) {
case *[]Params:
typ = 1
case *[]ParamsList:
typ = 2
case *ParamsList:
typ = 3
default:
panic(fmt.Errorf("<RawSeter> unsupport read values type `%T`", container))
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
}
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
indexs []int
)
for rs.Next() {
if cnt == 0 {
columns, err := rs.Columns()
if err != nil {
return 0, err
}
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
indexs = make([]int, 0, len(columns))
}
cols = columns
refs = make([]interface{}, len(cols))
for i := range refs {
var ref sql.NullString
refs[i] = &ref
if len(needCols) > 0 {
for _, c := range needCols {
if c == cols[i] {
indexs = append(indexs, i)
}
}
} else {
indexs = append(indexs, i)
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
switch typ {
case 1:
params := make(Params, len(cols))
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
params[cols[i]] = value.String
} else {
params[cols[i]] = nil
}
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
params = append(params, value.String)
} else {
params = append(params, nil)
}
}
lists = append(lists, params)
case 3:
for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid {
list = append(list, value.String)
} else {
list = append(list, nil)
}
}
}
cnt++
}
switch v := container.(type) {
case *[]Params:
*v = maps
case *[]ParamsList:
*v = lists
case *ParamsList:
*v = list
}
return cnt, nil
}
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
var (
maps Params
ind *reflect.Value
)
var typ int
switch container.(type) {
case *Params:
typ = 1
default:
typ = 2
vl := reflect.ValueOf(container)
id := reflect.Indirect(vl)
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
}
ind = &id
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
}
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
)
var (
keyIndex = -1
valueIndex = -1
)
for rs.Next() {
if cnt == 0 {
columns, err := rs.Columns()
if err != nil {
return 0, err
}
cols = columns
refs = make([]interface{}, len(cols))
for i := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 {
switch typ {
case 1:
maps = make(Params)
}
}
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
switch typ {
case 1:
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
if value.Valid {
maps[key] = value.String
} else {
maps[key] = nil
}
default:
if id := ind.FieldByName(camelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
}
}
cnt++
}
if typ == 1 {
v, _ := container.(*Params)
*v = maps
}
return cnt, nil
}
// query data to []map[string]interface
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query data to [][]interface
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(result, keyCol, valueCol)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
}
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o)
}
func newRawSet(orm *orm, query string, args []interface{}) RawSeter {
o := new(rawSet)
o.query = query
o.args = args
o.orm = orm
return o
}

2494
pkg/orm/orm_test.go Normal file

File diff suppressed because it is too large Load Diff

62
pkg/orm/qb.go Normal file
View File

@ -0,0 +1,62 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
ForUpdate() QueryBuilder
From(tables ...string) QueryBuilder
InnerJoin(table string) QueryBuilder
LeftJoin(table string) QueryBuilder
RightJoin(table string) QueryBuilder
On(cond string) QueryBuilder
Where(cond string) QueryBuilder
And(cond string) QueryBuilder
Or(cond string) QueryBuilder
In(vals ...string) QueryBuilder
OrderBy(fields ...string) QueryBuilder
Asc() QueryBuilder
Desc() QueryBuilder
Limit(limit int) QueryBuilder
Offset(offset int) QueryBuilder
GroupBy(fields ...string) QueryBuilder
Having(cond string) QueryBuilder
Update(tables ...string) QueryBuilder
Set(kv ...string) QueryBuilder
Delete(tables ...string) QueryBuilder
InsertInto(table string, fields ...string) QueryBuilder
Values(vals ...string) QueryBuilder
Subquery(sub string, alias string) string
String() string
}
// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
} else if driver == "tidb" {
qb = new(TiDBQueryBuilder)
} else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
err = errors.New("sqlite query builder is not supported yet")
} else {
err = errors.New("unknown driver for query builder")
}
return
}

185
pkg/orm/qb_mysql.go Normal file
View File

@ -0,0 +1,185 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
"strings"
)
// CommaSpace is the separation
const CommaSpace = ", "
// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// ForUpdate add the FOR UPDATE clause
func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder {
qb.Tokens = append(qb.Tokens, "FOR UPDATE")
return qb
}
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

182
pkg/orm/qb_tidb.go Normal file
View File

@ -0,0 +1,182 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
"strings"
)
// TiDBQueryBuilder is the SQL build
type TiDBQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// ForUpdate add the FOR UPDATE clause
func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder {
qb.Tokens = append(qb.Tokens, "FOR UPDATE")
return qb
}
// From join the tables
func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *TiDBQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

473
pkg/orm/types.go Normal file
View File

@ -0,0 +1,473 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"reflect"
"time"
)
// Driver define database driver
type Driver interface {
Name() string
Type() DriverType
}
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}
// Ormer define the orm interface
type Ormer interface {
// read data to model
// for example:
// this will find User by Id field
// u = &User{Id: user.Id}
// err = Ormer.Read(u)
// this will find User by UserName field
// u = &User{UserName: "astaxie", Password: "pass"}
// err = Ormer.Read(u, "UserName")
Read(md interface{}, cols ...string) error
// Like Read(), but with "FOR UPDATE" clause, useful in transaction.
// Some databases are not support this feature.
ReadForUpdate(md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
// insert model data to database
// for example:
// user := new(User)
// id, err = Ormer.Insert(user)
// user must be a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
// if colu type is integer : can use(+-*/), string : convert(colu,"value")
// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
// if colu type is integer : can use(+-*/), string : colu || "value"
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
// insert some models to database
InsertMulti(bulk int, mds interface{}) (int64, error)
// update model to database.
// cols set the columns those want to update.
// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
// for example:
// user := User{Id: 2}
// user.Langs = append(user.Langs, "zh-CN", "en-US")
// user.Extra.Name = "beego"
// user.Extra.Data = "orm"
// num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error)
// delete model in database
Delete(md interface{}, cols ...string) (int64, error)
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// Ormer.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//args[0] bool true useDefaultRelsDepth ; false depth 0
//args[0] int loadRelationDepth
//args[1] int limit default limit 1000
//args[2] int offset default offset 0
//args[3] string order for example : "-Id"
// make sure the relation is defined in model struct tags.
LoadRelated(md interface{}, name string, args ...interface{}) (int64, error)
// create a models to models queryer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// switch to another registered database driver by given name.
Using(name string) error
// begin transaction
// for example:
// o := NewOrm()
// err := o.Begin()
// ...
// err = o.Rollback()
Begin() error
// begin transaction with provided context and option
// the provided context is used until the transaction is committed or rolled back.
// if the context is canceled, the transaction will be rolled back.
// the provided TxOptions is optional and may be nil if defaults should be used.
// if a non-default isolation level is used that the driver doesn't support, an error will be returned.
// for example:
// o := NewOrm()
// err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
// ...
// err = o.Rollback()
BeginTx(ctx context.Context, opts *sql.TxOptions) error
// commit transaction
Commit() error
// rollback transaction
Rollback() error
// return a raw query seter for raw sql string.
// for example:
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
// // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter
Driver() Driver
DBStats() *sql.DBStats
}
// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
// QuerySeter query seter
type QuerySeter interface {
// add condition expression to QuerySeter.
// for example:
// filter by UserName == 'slene'
// qs.Filter("UserName", "slene")
// sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
// Filter("profile__Age", 28)
// // time compare
// qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
// add raw sql to querySeter.
// for example:
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
FilterRaw(string, string) QuerySeter
// add NOT condition to querySeter.
// have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
// set condition to QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
// get condition from QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
// qs = qs.SetCond(cond)
// cond = qs.GetCond()
// cond := cond.Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond).Count()
GetCond() *Condition
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
// for example:
// qs.Limit(10, 2)
// // sql-> limit 10 offset 2
Limit(limit interface{}, args ...interface{}) QuerySeter
// add OFFSET value
// same as Limit function's args[0]
Offset(offset interface{}) QuerySeter
// add GROUP BY expression
// for example:
// qs.GroupBy("id")
GroupBy(exprs ...string) QuerySeter
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// set relation model to query together.
// it will query relation models and assign to parent model.
// for example:
// // will load all related fields use left join .
// qs.RelatedSel().One(&user)
// // will load related field only profile
// qs.RelatedSel("profile").One(&user)
// user.Profile.Age = 32
RelatedSel(params ...interface{}) QuerySeter
// Set Distinct
// for example:
// o.QueryTable("policy").Filter("Groups__Group__Users__User", user).
// Distinct().
// All(&permissions)
Distinct() QuerySeter
// set FOR UPDATE to query.
// for example:
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
ForUpdate() QuerySeter
// return QuerySeter execution result number
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
// check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool
// execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
// "Nums": ColValue(Col_Minus, 50),
// }) // user slene's Nums will minus 50
// num, err = qs.Filter("UserName", "slene").Update(Params{
// "user_name": "slene2"
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
// delete from table
//for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error)
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// num, err = i.Insert(&user1) // user table will add one record user1 at once
// num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
// query all data and map to containers.
// cols means the columns when querying.
// for example:
// var users []*User
// qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error)
// query one row data and map to containers.
// cols means the columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
// for example:
// var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
}
// QueryM2Mer model to model query struct
// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface {
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}{ ... }
// param could also be any of the follow
// []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
// &Tag{Id:5,Name: "TestTag3"}
// []interface{}{&Tag{Id:6,Name: "TestTag4"}}
// insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
// remove models following the origin model relationship
// only delete rows from m2m table
// for example:
//tag3 := &Tag{Id:5,Name: "TestTag3"}
//num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
// check model is existed in relationship of origin model
Exist(interface{}) bool
// clean all models in related of origin model
Clear() (int64, error)
// count all related models of origin model
Count() (int64, error)
}
// RawPreparer raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
// RawSeter raw query seter
// create From Ormer.Raw
// for example:
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1)
type RawSeter interface {
//execute sql and get result
Exec() (sql.Result, error)
//query data and map to container
//for example:
// var name string
// var id int
// rs.QueryRow(&id,&name) // id==2 name=="slene"
QueryRow(containers ...interface{}) error
// query data rows and map to container
// var ids []int
// var names []int
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
// query data to []map[string]interface
// see QuerySeter's Values
Values(container *[]Params, cols ...string) (int64, error)
// query data to [][]interface
// see QuerySeter's ValuesList
ValuesList(container *[]ParamsList, cols ...string) (int64, error)
// query data to []interface
// see QuerySeter's ValuesFlat
ValuesFlat(container *ParamsList, cols ...string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// return prepared raw statement for used in times.
// for example:
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
Prepare() (RawPreparer, error)
}
// stmtQuerier statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
//QueryContext(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row
//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
}
// db querier
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// type DB interface {
// Begin() (*sql.Tx, error)
// Prepare(query string) (stmtQuerier, error)
// Exec(query string, args ...interface{}) (sql.Result, error)
// Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryRow(query string, args ...interface{}) *sql.Row
// }
// transaction beginner
type txer interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// transaction ending
type txEnder interface {
Commit() error
Rollback() error
}
// base database struct
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
}

319
pkg/orm/utils.go Normal file
View File

@ -0,0 +1,319 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"time"
)
type fn func(string) string
var (
nameStrategyMap = map[string]fn{
defaultNameStrategy: snakeString,
SnakeAcronymNameStrategy: snakeStringWithAcronym,
}
defaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
nameStrategy = defaultNameStrategy
)
// StrTo is the target string
type StrTo string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
return v, err
}
return ni.Int64(), nil
}
return v, err
}
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// Uint32 string to uint32
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
if err != nil {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
return v, err
}
return ni.Uint64(), nil
}
return v, err
}
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
case string:
s = v
case []byte:
s = string(v)
default:
s = fmt.Sprintf("%v", v)
}
return s
}
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
case int, int8, int16, int32, int64:
d = val.Int()
case uint, uint8, uint16, uint32, uint64:
d = int64(val.Uint())
default:
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
}
return
}
func snakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data[:]))
}
// snake string, XxYy to xx_yy , XxYY to xx_y_y
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data[:]))
}
// SetNameStrategy set different name strategy
func SetNameStrategy(s string) {
if SnakeAcronymNameStrategy != s {
nameStrategy = defaultNameStrategy
}
nameStrategy = s
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
if d == '_' {
flag = true
continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
}
data = append(data, d)
}
return string(data[:])
}
type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
return indirectType(v.Elem())
default:
return v
}
}

70
pkg/orm/utils_test.go Normal file
View File

@ -0,0 +1,70 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
)
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
for _, v := range snake {
res := camelString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
func TestSnakeString(t *testing.T) {
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"}
answer := make(map[string]string)
for i, v := range camel {
answer[v] = snake[i]
}
for _, v := range camel {
res := snakeString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
func TestSnakeStringWithAcronym(t *testing.T) {
camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"}
answer := make(map[string]string)
for i, v := range camel {
answer[v] = snake[i]
}
for _, v := range camel {
res := snakeStringWithAcronym(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}