1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 15:10:55 +00:00

orm Improve syncdb

This commit is contained in:
slene 2013-08-27 12:33:27 +08:00
parent 6686d9235c
commit 49bbca0ce3
8 changed files with 344 additions and 87 deletions

View File

@ -9,7 +9,7 @@ import (
type commander interface {
Parse([]string)
Run()
Run() error
}
var (
@ -59,9 +59,11 @@ func RunCommand() {
}
type commandSyncDb struct {
al *alias
force bool
verbose bool
al *alias
force bool
verbose bool
noInfo bool
rtOnError bool
}
func (d *commandSyncDb) Parse(args []string) {
@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name)
}
func (d *commandSyncDb) Run() {
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
drops = getDbDropSql(d.al)
@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() {
if d.force {
for i, mi := range modelCache.allOrdered() {
query := drops[i]
_, err := db.Exec(query)
result := ""
if err != nil {
result = err.Error()
if !d.noInfo {
fmt.Printf("drop table `%s`\n", mi.table)
}
fmt.Printf("drop table `%s` %s\n", mi.table, result)
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
sqls, indexes := getDbCreateSql(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for i, mi := range modelCache.allOrdered() {
fmt.Printf("create table `%s` \n", mi.table)
if tables[mi.table] {
if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table)
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; ok == false {
fields = append(fields, fi)
}
}
for _, fi := range fields {
query := getColumnAddQuery(d.al, fi)
if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
for _, idx := range indexes[mi.table] {
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.Sql
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
continue
}
if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table)
}
queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...)
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
}
for _, query := range queries {
_, err := db.Exec(query)
@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() {
fmt.Println(query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() {
fmt.Println("")
}
}
return nil
}
type commandSqlAll struct {
@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name)
}
func (d *commandSqlAll) Run() {
func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...)
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
}
fmt.Println(strings.Join(all, "\n\n"))
return nil
}
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll)
}
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()
al := getDbAlias(name)
cmd := new(commandSyncDb)
cmd.al = al
cmd.force = force
cmd.noInfo = !verbose
cmd.verbose = verbose
cmd.rtOnError = true
return cmd.Run()
}

View File

@ -6,6 +6,12 @@ import (
"strings"
)
type dbIndex struct {
Table string
Name string
Sql string
}
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
@ -31,7 +37,71 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls
}
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) {
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
goto checkColumn
}
return
}
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
if fi.null == false {
typ += " " + "NOT NULL"
}
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
}
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@ -41,7 +111,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
T := al.DbBaser.DbTypes()
sep := fmt.Sprintf("%s, %s", Q, Q)
tableIndexes = make(map[string][]string)
tableIndexes = make(map[string][]dbIndex)
for _, mi := range modelCache.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
@ -56,55 +126,8 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
for _, fi := range mi.fields.fieldsDB {
fieldType := fi.fieldType
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
col := ""
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
goto checkColumn
}
col := getColumnTyp(al, fi)
if fi.auto {
switch al.Driver {
@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
name := mi.table + "_" + strings.Join(names, "_")
cols := strings.Join(names, sep)
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
tableIndexes[mi.table] = append(tableIndexes[mi.table], sql)
index := dbIndex{}
index.Table = mi.table
index.Name = name
index.Sql = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
}

View File

@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
func (d *dbBase) DbTypes() map[string]string {
return nil
}
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool)
query := d.ins.ShowTablesQuery()
rows, err := db.Query(query)
if err != nil {
return tables, err
}
for rows.Next() {
var table string
err := rows.Scan(&table)
if err != nil {
return tables, err
}
if table != "" {
tables[table] = true
}
}
return tables, nil
}
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
if err != nil {
return columns, err
}
for rows.Next() {
var (
name string
typ string
null string
)
err := rows.Scan(&name, &typ, &null)
if err != nil {
return columns, err
}
columns[name] = [3]string{name, typ, null}
}
return columns, nil
}
func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement)
}
func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement)
}
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement)
}

View File

@ -1,5 +1,9 @@
package orm
import (
"fmt"
)
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b

View File

@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return
}
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b

View File

@ -1,6 +1,7 @@
package orm
import (
"database/sql"
"fmt"
)
@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
if err != nil {
return nil, err
}
columns := make(map[string][3]string)
for rows.Next() {
var tmp, name, typ, null sql.NullString
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
if err != nil {
return nil, err
}
columns[name.String] = [3]string{name.String, typ.String, null.String}
}
return columns, nil
}
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
if err != nil {
panic(err)
}
defer rows.Close()
for rows.Next() {
var tmp, index sql.NullString
rows.Scan(&tmp, &index, &tmp)
if name == index.String {
return true
}
}
return false
}
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b

View File

@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
BootStrap()
al := dataBaseCache.getDefault()
db := al.DB
drops := getDbDropSql(al)
for _, query := range drops {
_, err := db.Exec(query)
throwFail(t, err, query)
}
sqls, indexes := getDbCreateSql(al)
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
queries = append(queries, indexes[mi.table]...)
for _, query := range queries {
_, err := db.Exec(query)
throwFail(t, err, query)
}
}
err := RunSyncdb("default", true, false)
throwFail(t, err)
modelCache.clean()
}

View File

@ -133,4 +133,9 @@ type dbBaser interface {
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
GetTables(dbQuerier) (map[string]bool, error)
GetColumns(dbQuerier, string) (map[string][3]string, error)
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
}