1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 20:40:54 +00:00

orm Improve syncdb

This commit is contained in:
slene 2013-08-27 12:33:27 +08:00
parent 6686d9235c
commit 49bbca0ce3
8 changed files with 344 additions and 87 deletions

View File

@ -9,7 +9,7 @@ import (
type commander interface { type commander interface {
Parse([]string) Parse([]string)
Run() Run() error
} }
var ( var (
@ -59,9 +59,11 @@ func RunCommand() {
} }
type commandSyncDb struct { type commandSyncDb struct {
al *alias al *alias
force bool force bool
verbose bool verbose bool
noInfo bool
rtOnError bool
} }
func (d *commandSyncDb) Parse(args []string) { func (d *commandSyncDb) Parse(args []string) {
@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
func (d *commandSyncDb) Run() { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
drops = getDbDropSql(d.al) drops = getDbDropSql(d.al)
@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() {
if d.force { if d.force {
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
query := drops[i] query := drops[i]
_, err := db.Exec(query) if !d.noInfo {
result := "" fmt.Printf("drop table `%s`\n", mi.table)
if err != nil {
result = err.Error()
} }
fmt.Printf("drop table `%s` %s\n", mi.table, result) _, err := db.Exec(query)
if d.verbose { if d.verbose {
fmt.Printf(" %s\n\n", query) fmt.Printf(" %s\n\n", query)
} }
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
} }
} }
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSql(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
fmt.Printf("create table `%s` \n", mi.table) if tables[mi.table] {
if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table)
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; ok == false {
fields = append(fields, fi)
}
}
for _, fi := range fields {
query := getColumnAddQuery(d.al, fi)
if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
for _, idx := range indexes[mi.table] {
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.Sql
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
continue
}
if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table)
}
queries := []string{sqls[i]} queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...) for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
}
for _, query := range queries { for _, query := range queries {
_, err := db.Exec(query) _, err := db.Exec(query)
@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() {
fmt.Println(query) fmt.Println(query)
} }
if err != nil { if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error()) fmt.Printf(" %s\n", err.Error())
} }
} }
@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() {
fmt.Println("") fmt.Println("")
} }
} }
return nil
} }
type commandSqlAll struct { type commandSqlAll struct {
@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
func (d *commandSqlAll) Run() { func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSql(d.al)
var all []string var all []string
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]} queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...) for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
}
sql := strings.Join(queries, "\n") sql := strings.Join(queries, "\n")
all = append(all, sql) all = append(all, sql)
} }
fmt.Println(strings.Join(all, "\n\n")) fmt.Println(strings.Join(all, "\n\n"))
return nil
} }
func init() { func init() {
commands["syncdb"] = new(commandSyncDb) commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSqlAll)
} }
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()
al := getDbAlias(name)
cmd := new(commandSyncDb)
cmd.al = al
cmd.force = force
cmd.noInfo = !verbose
cmd.verbose = verbose
cmd.rtOnError = true
return cmd.Run()
}

View File

@ -6,6 +6,12 @@ import (
"strings" "strings"
) )
type dbIndex struct {
Table string
Name string
Sql string
}
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
@ -31,7 +37,71 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls return sqls
} }
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
goto checkColumn
}
return
}
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
if fi.null == false {
typ += " " + "NOT NULL"
}
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
}
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
os.Exit(2) os.Exit(2)
@ -41,7 +111,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
sep := fmt.Sprintf("%s, %s", Q, Q) sep := fmt.Sprintf("%s, %s", Q, Q)
tableIndexes = make(map[string][]string) tableIndexes = make(map[string][]dbIndex)
for _, mi := range modelCache.allOrdered() { for _, mi := range modelCache.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
@ -56,55 +126,8 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.fields.fieldsDB {
fieldType := fi.fieldType
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
col := "" col := getColumnTyp(al, fi)
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
goto checkColumn
}
if fi.auto { if fi.auto {
switch al.Driver { switch al.Driver {
@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
name := mi.table + "_" + strings.Join(names, "_") name := mi.table + "_" + strings.Join(names, "_")
cols := strings.Join(names, sep) cols := strings.Join(names, sep)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
tableIndexes[mi.table] = append(tableIndexes[mi.table], sql)
index := dbIndex{}
index.Table = mi.table
index.Name = name
index.Sql = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
} }
} }

View File

@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
func (d *dbBase) DbTypes() map[string]string { func (d *dbBase) DbTypes() map[string]string {
return nil return nil
} }
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool)
query := d.ins.ShowTablesQuery()
rows, err := db.Query(query)
if err != nil {
return tables, err
}
for rows.Next() {
var table string
err := rows.Scan(&table)
if err != nil {
return tables, err
}
if table != "" {
tables[table] = true
}
}
return tables, nil
}
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
if err != nil {
return columns, err
}
for rows.Next() {
var (
name string
typ string
null string
)
err := rows.Scan(&name, &typ, &null)
if err != nil {
return columns, err
}
columns[name] = [3]string{name, typ, null}
}
return columns, nil
}
func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement)
}
func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement)
}
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement)
}

View File

@ -1,5 +1,9 @@
package orm package orm
import (
"fmt"
)
var mysqlOperators = map[string]string{ var mysqlOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ?", "iexact": "LIKE ?",
@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes return mysqlTypes
} }
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
func newdbBaseMysql() dbBaser { func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql) b := new(dbBaseMysql)
b.ins = b b.ins = b

View File

@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return return
} }
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
func (d *dbBasePostgres) DbTypes() map[string]string { func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes return postgresTypes
} }
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)
b.ins = b b.ins = b

View File

@ -1,6 +1,7 @@
package orm package orm
import ( import (
"database/sql"
"fmt" "fmt"
) )
@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes return sqliteTypes
} }
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
if err != nil {
return nil, err
}
columns := make(map[string][3]string)
for rows.Next() {
var tmp, name, typ, null sql.NullString
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
if err != nil {
return nil, err
}
columns[name.String] = [3]string{name.String, typ.String, null.String}
}
return columns, nil
}
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
if err != nil {
panic(err)
}
defer rows.Close()
for rows.Next() {
var tmp, index sql.NullString
rows.Scan(&tmp, &index, &tmp)
if name == index.String {
return true
}
}
return false
}
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)
b.ins = b b.ins = b

View File

@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment)) RegisterModel(new(Comment))
RegisterModel(new(UserBig)) RegisterModel(new(UserBig))
BootStrap() err := RunSyncdb("default", true, false)
throwFail(t, err)
al := dataBaseCache.getDefault()
db := al.DB
drops := getDbDropSql(al)
for _, query := range drops {
_, err := db.Exec(query)
throwFail(t, err, query)
}
sqls, indexes := getDbCreateSql(al)
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...)
for _, query := range queries {
_, err := db.Exec(query)
throwFail(t, err, query)
}
}
modelCache.clean() modelCache.clean()
} }

View File

@ -133,4 +133,9 @@ type dbBaser interface {
TimeFromDB(*time.Time, *time.Location) TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location) TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
} }