mirror of
https://github.com/astaxie/beego.git
synced 2024-12-22 20:30:49 +00:00
orm Improve syncdb
This commit is contained in:
parent
6686d9235c
commit
49bbca0ce3
130
orm/cmd.go
130
orm/cmd.go
@ -9,7 +9,7 @@ import (
|
||||
|
||||
type commander interface {
|
||||
Parse([]string)
|
||||
Run()
|
||||
Run() error
|
||||
}
|
||||
|
||||
var (
|
||||
@ -59,9 +59,11 @@ func RunCommand() {
|
||||
}
|
||||
|
||||
type commandSyncDb struct {
|
||||
al *alias
|
||||
force bool
|
||||
verbose bool
|
||||
al *alias
|
||||
force bool
|
||||
verbose bool
|
||||
noInfo bool
|
||||
rtOnError bool
|
||||
}
|
||||
|
||||
func (d *commandSyncDb) Parse(args []string) {
|
||||
@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
func (d *commandSyncDb) Run() {
|
||||
func (d *commandSyncDb) Run() error {
|
||||
var drops []string
|
||||
if d.force {
|
||||
drops = getDbDropSql(d.al)
|
||||
@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() {
|
||||
if d.force {
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
query := drops[i]
|
||||
_, err := db.Exec(query)
|
||||
result := ""
|
||||
if err != nil {
|
||||
result = err.Error()
|
||||
if !d.noInfo {
|
||||
fmt.Printf("drop table `%s`\n", mi.table)
|
||||
}
|
||||
fmt.Printf("drop table `%s` %s\n", mi.table, result)
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqls, indexes := getDbCreateSql(d.al)
|
||||
|
||||
tables, err := d.al.DbBaser.GetTables(db)
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
fmt.Printf("create table `%s` \n", mi.table)
|
||||
if tables[mi.table] {
|
||||
if !d.noInfo {
|
||||
fmt.Printf("table `%s` already exists, skip\n", mi.table)
|
||||
}
|
||||
|
||||
var fields []*fieldInfo
|
||||
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
if _, ok := columns[fi.column]; ok == false {
|
||||
fields = append(fields, fi)
|
||||
}
|
||||
}
|
||||
|
||||
for _, fi := range fields {
|
||||
query := getColumnAddQuery(d.al, fi)
|
||||
|
||||
if !d.noInfo {
|
||||
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
|
||||
}
|
||||
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range indexes[mi.table] {
|
||||
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
|
||||
if !d.noInfo {
|
||||
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||
}
|
||||
|
||||
query := idx.Sql
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !d.noInfo {
|
||||
fmt.Printf("create table `%s` \n", mi.table)
|
||||
}
|
||||
|
||||
queries := []string{sqls[i]}
|
||||
queries = append(queries, indexes[mi.table]...)
|
||||
for _, idx := range indexes[mi.table] {
|
||||
queries = append(queries, idx.Sql)
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
_, err := db.Exec(query)
|
||||
@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() {
|
||||
fmt.Println(query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() {
|
||||
fmt.Println("")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type commandSqlAll struct {
|
||||
@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
func (d *commandSqlAll) Run() {
|
||||
func (d *commandSqlAll) Run() error {
|
||||
sqls, indexes := getDbCreateSql(d.al)
|
||||
var all []string
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
queries := []string{sqls[i]}
|
||||
queries = append(queries, indexes[mi.table]...)
|
||||
for _, idx := range indexes[mi.table] {
|
||||
queries = append(queries, idx.Sql)
|
||||
}
|
||||
sql := strings.Join(queries, "\n")
|
||||
all = append(all, sql)
|
||||
}
|
||||
fmt.Println(strings.Join(all, "\n\n"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
commands["syncdb"] = new(commandSyncDb)
|
||||
commands["sqlall"] = new(commandSqlAll)
|
||||
}
|
||||
|
||||
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||
BootStrap()
|
||||
|
||||
al := getDbAlias(name)
|
||||
cmd := new(commandSyncDb)
|
||||
cmd.al = al
|
||||
cmd.force = force
|
||||
cmd.noInfo = !verbose
|
||||
cmd.verbose = verbose
|
||||
cmd.rtOnError = true
|
||||
return cmd.Run()
|
||||
}
|
||||
|
131
orm/cmd_utils.go
131
orm/cmd_utils.go
@ -6,6 +6,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type dbIndex struct {
|
||||
Table string
|
||||
Name string
|
||||
Sql string
|
||||
}
|
||||
|
||||
func getDbAlias(name string) *alias {
|
||||
if al, ok := dataBaseCache.get(name); ok {
|
||||
return al
|
||||
@ -31,7 +37,71 @@ func getDbDropSql(al *alias) (sqls []string) {
|
||||
return sqls
|
||||
}
|
||||
|
||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) {
|
||||
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||
T := al.DbBaser.DbTypes()
|
||||
fieldType := fi.fieldType
|
||||
|
||||
checkColumn:
|
||||
switch fieldType {
|
||||
case TypeBooleanField:
|
||||
col = T["bool"]
|
||||
case TypeCharField:
|
||||
col = fmt.Sprintf(T["string"], fi.size)
|
||||
case TypeTextField:
|
||||
col = T["string-text"]
|
||||
case TypeDateField:
|
||||
col = T["time.Time-date"]
|
||||
case TypeDateTimeField:
|
||||
col = T["time.Time"]
|
||||
case TypeBitField:
|
||||
col = T["int8"]
|
||||
case TypeSmallIntegerField:
|
||||
col = T["int16"]
|
||||
case TypeIntegerField:
|
||||
col = T["int32"]
|
||||
case TypeBigIntegerField:
|
||||
if al.Driver == DR_Sqlite {
|
||||
fieldType = TypeIntegerField
|
||||
goto checkColumn
|
||||
}
|
||||
col = T["int64"]
|
||||
case TypePositiveBitField:
|
||||
col = T["uint8"]
|
||||
case TypePositiveSmallIntegerField:
|
||||
col = T["uint16"]
|
||||
case TypePositiveIntegerField:
|
||||
col = T["uint32"]
|
||||
case TypePositiveBigIntegerField:
|
||||
col = T["uint64"]
|
||||
case TypeFloatField:
|
||||
col = T["float64"]
|
||||
case TypeDecimalField:
|
||||
s := T["float64-decimal"]
|
||||
if strings.Index(s, "%d") == -1 {
|
||||
col = s
|
||||
} else {
|
||||
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
||||
}
|
||||
case RelForeignKey, RelOneToOne:
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
goto checkColumn
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
Q := al.DbBaser.TableQuote()
|
||||
typ := getColumnTyp(al, fi)
|
||||
|
||||
if fi.null == false {
|
||||
typ += " " + "NOT NULL"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
||||
}
|
||||
|
||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
os.Exit(2)
|
||||
@ -41,7 +111,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
||||
T := al.DbBaser.DbTypes()
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
|
||||
tableIndexes = make(map[string][]string)
|
||||
tableIndexes = make(map[string][]dbIndex)
|
||||
|
||||
for _, mi := range modelCache.allOrdered() {
|
||||
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
|
||||
@ -56,55 +126,8 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
||||
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
|
||||
fieldType := fi.fieldType
|
||||
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
|
||||
col := ""
|
||||
|
||||
checkColumn:
|
||||
switch fieldType {
|
||||
case TypeBooleanField:
|
||||
col = T["bool"]
|
||||
case TypeCharField:
|
||||
col = fmt.Sprintf(T["string"], fi.size)
|
||||
case TypeTextField:
|
||||
col = T["string-text"]
|
||||
case TypeDateField:
|
||||
col = T["time.Time-date"]
|
||||
case TypeDateTimeField:
|
||||
col = T["time.Time"]
|
||||
case TypeBitField:
|
||||
col = T["int8"]
|
||||
case TypeSmallIntegerField:
|
||||
col = T["int16"]
|
||||
case TypeIntegerField:
|
||||
col = T["int32"]
|
||||
case TypeBigIntegerField:
|
||||
if al.Driver == DR_Sqlite {
|
||||
fieldType = TypeIntegerField
|
||||
goto checkColumn
|
||||
}
|
||||
col = T["int64"]
|
||||
case TypePositiveBitField:
|
||||
col = T["uint8"]
|
||||
case TypePositiveSmallIntegerField:
|
||||
col = T["uint16"]
|
||||
case TypePositiveIntegerField:
|
||||
col = T["uint32"]
|
||||
case TypePositiveBigIntegerField:
|
||||
col = T["uint64"]
|
||||
case TypeFloatField:
|
||||
col = T["float64"]
|
||||
case TypeDecimalField:
|
||||
s := T["float64-decimal"]
|
||||
if strings.Index(s, "%d") == -1 {
|
||||
col = s
|
||||
} else {
|
||||
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
||||
}
|
||||
case RelForeignKey, RelOneToOne:
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
goto checkColumn
|
||||
}
|
||||
col := getColumnTyp(al, fi)
|
||||
|
||||
if fi.auto {
|
||||
switch al.Driver {
|
||||
@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
||||
name := mi.table + "_" + strings.Join(names, "_")
|
||||
cols := strings.Join(names, sep)
|
||||
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
|
||||
tableIndexes[mi.table] = append(tableIndexes[mi.table], sql)
|
||||
|
||||
index := dbIndex{}
|
||||
index.Table = mi.table
|
||||
index.Name = name
|
||||
index.Sql = sql
|
||||
|
||||
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
|
||||
}
|
||||
|
||||
}
|
||||
|
58
orm/db.go
58
orm/db.go
@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
||||
func (d *dbBase) DbTypes() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
tables := make(map[string]bool)
|
||||
query := d.ins.ShowTablesQuery()
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return tables, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var table string
|
||||
err := rows.Scan(&table)
|
||||
if err != nil {
|
||||
return tables, err
|
||||
}
|
||||
if table != "" {
|
||||
tables[table] = true
|
||||
}
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
columns := make(map[string][3]string)
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return columns, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
typ string
|
||||
null string
|
||||
)
|
||||
err := rows.Scan(&name, &typ, &null)
|
||||
if err != nil {
|
||||
return columns, err
|
||||
}
|
||||
columns[name] = [3]string{name, typ, null}
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) ShowTablesQuery() string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
func (d *dbBase) ShowColumnsQuery(table string) string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
@ -1,5 +1,9 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var mysqlOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ?",
|
||||
@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string {
|
||||
return mysqlTypes
|
||||
}
|
||||
|
||||
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||
}
|
||||
|
||||
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||
}
|
||||
|
||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
func newdbBaseMysql() dbBaser {
|
||||
b := new(dbBaseMysql)
|
||||
b.ins = b
|
||||
|
@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||
}
|
||||
|
||||
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||
}
|
||||
|
||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||
return postgresTypes
|
||||
}
|
||||
|
||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||
row := db.QueryRow(query)
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
func newdbBasePostgres() dbBaser {
|
||||
b := new(dbBasePostgres)
|
||||
b.ins = b
|
||||
|
@ -1,6 +1,7 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string {
|
||||
return sqliteTypes
|
||||
}
|
||||
|
||||
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
}
|
||||
|
||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make(map[string][3]string)
|
||||
for rows.Next() {
|
||||
var tmp, name, typ, null sql.NullString
|
||||
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columns[name.String] = [3]string{name.String, typ.String, null.String}
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||
}
|
||||
|
||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var tmp, index sql.NullString
|
||||
rows.Scan(&tmp, &index, &tmp)
|
||||
if name == index.String {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newdbBaseSqlite() dbBaser {
|
||||
b := new(dbBaseSqlite)
|
||||
b.ins = b
|
||||
|
@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) {
|
||||
RegisterModel(new(Comment))
|
||||
RegisterModel(new(UserBig))
|
||||
|
||||
BootStrap()
|
||||
|
||||
al := dataBaseCache.getDefault()
|
||||
db := al.DB
|
||||
|
||||
drops := getDbDropSql(al)
|
||||
for _, query := range drops {
|
||||
_, err := db.Exec(query)
|
||||
throwFail(t, err, query)
|
||||
}
|
||||
|
||||
sqls, indexes := getDbCreateSql(al)
|
||||
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
queries := []string{sqls[i]}
|
||||
queries = append(queries, indexes[mi.table]...)
|
||||
|
||||
for _, query := range queries {
|
||||
_, err := db.Exec(query)
|
||||
throwFail(t, err, query)
|
||||
}
|
||||
}
|
||||
err := RunSyncdb("default", true, false)
|
||||
throwFail(t, err)
|
||||
|
||||
modelCache.clean()
|
||||
}
|
||||
|
@ -133,4 +133,9 @@ type dbBaser interface {
|
||||
TimeFromDB(*time.Time, *time.Location)
|
||||
TimeToDB(*time.Time, *time.Location)
|
||||
DbTypes() map[string]string
|
||||
GetTables(dbQuerier) (map[string]bool, error)
|
||||
GetColumns(dbQuerier, string) (map[string][3]string, error)
|
||||
ShowTablesQuery() string
|
||||
ShowColumnsQuery(string) string
|
||||
IndexExists(dbQuerier, string, string) bool
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user