mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 15:00:54 +00:00
orm Improve syncdb
This commit is contained in:
parent
6686d9235c
commit
49bbca0ce3
130
orm/cmd.go
130
orm/cmd.go
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
type commander interface {
|
type commander interface {
|
||||||
Parse([]string)
|
Parse([]string)
|
||||||
Run()
|
Run() error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -59,9 +59,11 @@ func RunCommand() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type commandSyncDb struct {
|
type commandSyncDb struct {
|
||||||
al *alias
|
al *alias
|
||||||
force bool
|
force bool
|
||||||
verbose bool
|
verbose bool
|
||||||
|
noInfo bool
|
||||||
|
rtOnError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *commandSyncDb) Parse(args []string) {
|
func (d *commandSyncDb) Parse(args []string) {
|
||||||
@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) {
|
|||||||
d.al = getDbAlias(name)
|
d.al = getDbAlias(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *commandSyncDb) Run() {
|
func (d *commandSyncDb) Run() error {
|
||||||
var drops []string
|
var drops []string
|
||||||
if d.force {
|
if d.force {
|
||||||
drops = getDbDropSql(d.al)
|
drops = getDbDropSql(d.al)
|
||||||
@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() {
|
|||||||
if d.force {
|
if d.force {
|
||||||
for i, mi := range modelCache.allOrdered() {
|
for i, mi := range modelCache.allOrdered() {
|
||||||
query := drops[i]
|
query := drops[i]
|
||||||
_, err := db.Exec(query)
|
if !d.noInfo {
|
||||||
result := ""
|
fmt.Printf("drop table `%s`\n", mi.table)
|
||||||
if err != nil {
|
|
||||||
result = err.Error()
|
|
||||||
}
|
}
|
||||||
fmt.Printf("drop table `%s` %s\n", mi.table, result)
|
_, err := db.Exec(query)
|
||||||
if d.verbose {
|
if d.verbose {
|
||||||
fmt.Printf(" %s\n\n", query)
|
fmt.Printf(" %s\n\n", query)
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s\n", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqls, indexes := getDbCreateSql(d.al)
|
sqls, indexes := getDbCreateSql(d.al)
|
||||||
|
|
||||||
|
tables, err := d.al.DbBaser.GetTables(db)
|
||||||
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s\n", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
for i, mi := range modelCache.allOrdered() {
|
for i, mi := range modelCache.allOrdered() {
|
||||||
fmt.Printf("create table `%s` \n", mi.table)
|
if tables[mi.table] {
|
||||||
|
if !d.noInfo {
|
||||||
|
fmt.Printf("table `%s` already exists, skip\n", mi.table)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fields []*fieldInfo
|
||||||
|
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
|
||||||
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s\n", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
|
if _, ok := columns[fi.column]; ok == false {
|
||||||
|
fields = append(fields, fi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fi := range fields {
|
||||||
|
query := getColumnAddQuery(d.al, fi)
|
||||||
|
|
||||||
|
if !d.noInfo {
|
||||||
|
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.Exec(query)
|
||||||
|
if d.verbose {
|
||||||
|
fmt.Printf(" %s\n", query)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s\n", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, idx := range indexes[mi.table] {
|
||||||
|
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
|
||||||
|
if !d.noInfo {
|
||||||
|
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := idx.Sql
|
||||||
|
_, err := db.Exec(query)
|
||||||
|
if d.verbose {
|
||||||
|
fmt.Printf(" %s\n", query)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s\n", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !d.noInfo {
|
||||||
|
fmt.Printf("create table `%s` \n", mi.table)
|
||||||
|
}
|
||||||
|
|
||||||
queries := []string{sqls[i]}
|
queries := []string{sqls[i]}
|
||||||
queries = append(queries, indexes[mi.table]...)
|
for _, idx := range indexes[mi.table] {
|
||||||
|
queries = append(queries, idx.Sql)
|
||||||
|
}
|
||||||
|
|
||||||
for _, query := range queries {
|
for _, query := range queries {
|
||||||
_, err := db.Exec(query)
|
_, err := db.Exec(query)
|
||||||
@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() {
|
|||||||
fmt.Println(query)
|
fmt.Println(query)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if d.rtOnError {
|
||||||
|
return err
|
||||||
|
}
|
||||||
fmt.Printf(" %s\n", err.Error())
|
fmt.Printf(" %s\n", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() {
|
|||||||
fmt.Println("")
|
fmt.Println("")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type commandSqlAll struct {
|
type commandSqlAll struct {
|
||||||
@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) {
|
|||||||
d.al = getDbAlias(name)
|
d.al = getDbAlias(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *commandSqlAll) Run() {
|
func (d *commandSqlAll) Run() error {
|
||||||
sqls, indexes := getDbCreateSql(d.al)
|
sqls, indexes := getDbCreateSql(d.al)
|
||||||
var all []string
|
var all []string
|
||||||
for i, mi := range modelCache.allOrdered() {
|
for i, mi := range modelCache.allOrdered() {
|
||||||
queries := []string{sqls[i]}
|
queries := []string{sqls[i]}
|
||||||
queries = append(queries, indexes[mi.table]...)
|
for _, idx := range indexes[mi.table] {
|
||||||
|
queries = append(queries, idx.Sql)
|
||||||
|
}
|
||||||
sql := strings.Join(queries, "\n")
|
sql := strings.Join(queries, "\n")
|
||||||
all = append(all, sql)
|
all = append(all, sql)
|
||||||
}
|
}
|
||||||
fmt.Println(strings.Join(all, "\n\n"))
|
fmt.Println(strings.Join(all, "\n\n"))
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
commands["syncdb"] = new(commandSyncDb)
|
commands["syncdb"] = new(commandSyncDb)
|
||||||
commands["sqlall"] = new(commandSqlAll)
|
commands["sqlall"] = new(commandSqlAll)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||||
|
BootStrap()
|
||||||
|
|
||||||
|
al := getDbAlias(name)
|
||||||
|
cmd := new(commandSyncDb)
|
||||||
|
cmd.al = al
|
||||||
|
cmd.force = force
|
||||||
|
cmd.noInfo = !verbose
|
||||||
|
cmd.verbose = verbose
|
||||||
|
cmd.rtOnError = true
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
131
orm/cmd_utils.go
131
orm/cmd_utils.go
@ -6,6 +6,12 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type dbIndex struct {
|
||||||
|
Table string
|
||||||
|
Name string
|
||||||
|
Sql string
|
||||||
|
}
|
||||||
|
|
||||||
func getDbAlias(name string) *alias {
|
func getDbAlias(name string) *alias {
|
||||||
if al, ok := dataBaseCache.get(name); ok {
|
if al, ok := dataBaseCache.get(name); ok {
|
||||||
return al
|
return al
|
||||||
@ -31,7 +37,71 @@ func getDbDropSql(al *alias) (sqls []string) {
|
|||||||
return sqls
|
return sqls
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) {
|
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||||
|
T := al.DbBaser.DbTypes()
|
||||||
|
fieldType := fi.fieldType
|
||||||
|
|
||||||
|
checkColumn:
|
||||||
|
switch fieldType {
|
||||||
|
case TypeBooleanField:
|
||||||
|
col = T["bool"]
|
||||||
|
case TypeCharField:
|
||||||
|
col = fmt.Sprintf(T["string"], fi.size)
|
||||||
|
case TypeTextField:
|
||||||
|
col = T["string-text"]
|
||||||
|
case TypeDateField:
|
||||||
|
col = T["time.Time-date"]
|
||||||
|
case TypeDateTimeField:
|
||||||
|
col = T["time.Time"]
|
||||||
|
case TypeBitField:
|
||||||
|
col = T["int8"]
|
||||||
|
case TypeSmallIntegerField:
|
||||||
|
col = T["int16"]
|
||||||
|
case TypeIntegerField:
|
||||||
|
col = T["int32"]
|
||||||
|
case TypeBigIntegerField:
|
||||||
|
if al.Driver == DR_Sqlite {
|
||||||
|
fieldType = TypeIntegerField
|
||||||
|
goto checkColumn
|
||||||
|
}
|
||||||
|
col = T["int64"]
|
||||||
|
case TypePositiveBitField:
|
||||||
|
col = T["uint8"]
|
||||||
|
case TypePositiveSmallIntegerField:
|
||||||
|
col = T["uint16"]
|
||||||
|
case TypePositiveIntegerField:
|
||||||
|
col = T["uint32"]
|
||||||
|
case TypePositiveBigIntegerField:
|
||||||
|
col = T["uint64"]
|
||||||
|
case TypeFloatField:
|
||||||
|
col = T["float64"]
|
||||||
|
case TypeDecimalField:
|
||||||
|
s := T["float64-decimal"]
|
||||||
|
if strings.Index(s, "%d") == -1 {
|
||||||
|
col = s
|
||||||
|
} else {
|
||||||
|
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
||||||
|
}
|
||||||
|
case RelForeignKey, RelOneToOne:
|
||||||
|
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||||
|
goto checkColumn
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||||
|
Q := al.DbBaser.TableQuote()
|
||||||
|
typ := getColumnTyp(al, fi)
|
||||||
|
|
||||||
|
if fi.null == false {
|
||||||
|
typ += " " + "NOT NULL"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||||
if len(modelCache.cache) == 0 {
|
if len(modelCache.cache) == 0 {
|
||||||
fmt.Println("no Model found, need register your model")
|
fmt.Println("no Model found, need register your model")
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
@ -41,7 +111,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
|||||||
T := al.DbBaser.DbTypes()
|
T := al.DbBaser.DbTypes()
|
||||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
|
||||||
tableIndexes = make(map[string][]string)
|
tableIndexes = make(map[string][]dbIndex)
|
||||||
|
|
||||||
for _, mi := range modelCache.allOrdered() {
|
for _, mi := range modelCache.allOrdered() {
|
||||||
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
|
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
|
||||||
@ -56,55 +126,8 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
|||||||
|
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
|
|
||||||
fieldType := fi.fieldType
|
|
||||||
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
|
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
|
||||||
col := ""
|
col := getColumnTyp(al, fi)
|
||||||
|
|
||||||
checkColumn:
|
|
||||||
switch fieldType {
|
|
||||||
case TypeBooleanField:
|
|
||||||
col = T["bool"]
|
|
||||||
case TypeCharField:
|
|
||||||
col = fmt.Sprintf(T["string"], fi.size)
|
|
||||||
case TypeTextField:
|
|
||||||
col = T["string-text"]
|
|
||||||
case TypeDateField:
|
|
||||||
col = T["time.Time-date"]
|
|
||||||
case TypeDateTimeField:
|
|
||||||
col = T["time.Time"]
|
|
||||||
case TypeBitField:
|
|
||||||
col = T["int8"]
|
|
||||||
case TypeSmallIntegerField:
|
|
||||||
col = T["int16"]
|
|
||||||
case TypeIntegerField:
|
|
||||||
col = T["int32"]
|
|
||||||
case TypeBigIntegerField:
|
|
||||||
if al.Driver == DR_Sqlite {
|
|
||||||
fieldType = TypeIntegerField
|
|
||||||
goto checkColumn
|
|
||||||
}
|
|
||||||
col = T["int64"]
|
|
||||||
case TypePositiveBitField:
|
|
||||||
col = T["uint8"]
|
|
||||||
case TypePositiveSmallIntegerField:
|
|
||||||
col = T["uint16"]
|
|
||||||
case TypePositiveIntegerField:
|
|
||||||
col = T["uint32"]
|
|
||||||
case TypePositiveBigIntegerField:
|
|
||||||
col = T["uint64"]
|
|
||||||
case TypeFloatField:
|
|
||||||
col = T["float64"]
|
|
||||||
case TypeDecimalField:
|
|
||||||
s := T["float64-decimal"]
|
|
||||||
if strings.Index(s, "%d") == -1 {
|
|
||||||
col = s
|
|
||||||
} else {
|
|
||||||
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
|
||||||
}
|
|
||||||
case RelForeignKey, RelOneToOne:
|
|
||||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
|
||||||
goto checkColumn
|
|
||||||
}
|
|
||||||
|
|
||||||
if fi.auto {
|
if fi.auto {
|
||||||
switch al.Driver {
|
switch al.Driver {
|
||||||
@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
|
|||||||
name := mi.table + "_" + strings.Join(names, "_")
|
name := mi.table + "_" + strings.Join(names, "_")
|
||||||
cols := strings.Join(names, sep)
|
cols := strings.Join(names, sep)
|
||||||
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
|
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
|
||||||
tableIndexes[mi.table] = append(tableIndexes[mi.table], sql)
|
|
||||||
|
index := dbIndex{}
|
||||||
|
index.Table = mi.table
|
||||||
|
index.Name = name
|
||||||
|
index.Sql = sql
|
||||||
|
|
||||||
|
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
58
orm/db.go
58
orm/db.go
@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
|||||||
func (d *dbBase) DbTypes() map[string]string {
|
func (d *dbBase) DbTypes() map[string]string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||||
|
tables := make(map[string]bool)
|
||||||
|
query := d.ins.ShowTablesQuery()
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return tables, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var table string
|
||||||
|
err := rows.Scan(&table)
|
||||||
|
if err != nil {
|
||||||
|
return tables, err
|
||||||
|
}
|
||||||
|
if table != "" {
|
||||||
|
tables[table] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tables, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
|
columns := make(map[string][3]string)
|
||||||
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return columns, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
name string
|
||||||
|
typ string
|
||||||
|
null string
|
||||||
|
)
|
||||||
|
err := rows.Scan(&name, &typ, &null)
|
||||||
|
if err != nil {
|
||||||
|
return columns, err
|
||||||
|
}
|
||||||
|
columns[name] = [3]string{name, typ, null}
|
||||||
|
}
|
||||||
|
|
||||||
|
return columns, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) ShowTablesQuery() string {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) ShowColumnsQuery(table string) string {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
}
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
var mysqlOperators = map[string]string{
|
var mysqlOperators = map[string]string{
|
||||||
"exact": "= ?",
|
"exact": "= ?",
|
||||||
"iexact": "LIKE ?",
|
"iexact": "LIKE ?",
|
||||||
@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string {
|
|||||||
return mysqlTypes
|
return mysqlTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||||
|
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||||
|
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||||
|
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
|
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||||
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
|
var cnt int
|
||||||
|
row.Scan(&cnt)
|
||||||
|
return cnt > 0
|
||||||
|
}
|
||||||
|
|
||||||
func newdbBaseMysql() dbBaser {
|
func newdbBaseMysql() dbBaser {
|
||||||
b := new(dbBaseMysql)
|
b := new(dbBaseMysql)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||||
|
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||||
|
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||||
return postgresTypes
|
return postgresTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
|
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||||
|
row := db.QueryRow(query)
|
||||||
|
var cnt int
|
||||||
|
row.Scan(&cnt)
|
||||||
|
return cnt > 0
|
||||||
|
}
|
||||||
|
|
||||||
func newdbBasePostgres() dbBaser {
|
func newdbBasePostgres() dbBaser {
|
||||||
b := new(dbBasePostgres)
|
b := new(dbBasePostgres)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string {
|
|||||||
return sqliteTypes
|
return sqliteTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||||
|
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
columns := make(map[string][3]string)
|
||||||
|
for rows.Next() {
|
||||||
|
var tmp, name, typ, null sql.NullString
|
||||||
|
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
columns[name.String] = [3]string{name.String, typ.String, null.String}
|
||||||
|
}
|
||||||
|
|
||||||
|
return columns, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||||
|
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
|
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var tmp, index sql.NullString
|
||||||
|
rows.Scan(&tmp, &index, &tmp)
|
||||||
|
if name == index.String {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func newdbBaseSqlite() dbBaser {
|
func newdbBaseSqlite() dbBaser {
|
||||||
b := new(dbBaseSqlite)
|
b := new(dbBaseSqlite)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) {
|
|||||||
RegisterModel(new(Comment))
|
RegisterModel(new(Comment))
|
||||||
RegisterModel(new(UserBig))
|
RegisterModel(new(UserBig))
|
||||||
|
|
||||||
BootStrap()
|
err := RunSyncdb("default", true, false)
|
||||||
|
throwFail(t, err)
|
||||||
al := dataBaseCache.getDefault()
|
|
||||||
db := al.DB
|
|
||||||
|
|
||||||
drops := getDbDropSql(al)
|
|
||||||
for _, query := range drops {
|
|
||||||
_, err := db.Exec(query)
|
|
||||||
throwFail(t, err, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
sqls, indexes := getDbCreateSql(al)
|
|
||||||
|
|
||||||
for i, mi := range modelCache.allOrdered() {
|
|
||||||
queries := []string{sqls[i]}
|
|
||||||
queries = append(queries, indexes[mi.table]...)
|
|
||||||
|
|
||||||
for _, query := range queries {
|
|
||||||
_, err := db.Exec(query)
|
|
||||||
throwFail(t, err, query)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
modelCache.clean()
|
modelCache.clean()
|
||||||
}
|
}
|
||||||
|
@ -133,4 +133,9 @@ type dbBaser interface {
|
|||||||
TimeFromDB(*time.Time, *time.Location)
|
TimeFromDB(*time.Time, *time.Location)
|
||||||
TimeToDB(*time.Time, *time.Location)
|
TimeToDB(*time.Time, *time.Location)
|
||||||
DbTypes() map[string]string
|
DbTypes() map[string]string
|
||||||
|
GetTables(dbQuerier) (map[string]bool, error)
|
||||||
|
GetColumns(dbQuerier, string) (map[string][3]string, error)
|
||||||
|
ShowTablesQuery() string
|
||||||
|
ShowColumnsQuery(string) string
|
||||||
|
IndexExists(dbQuerier, string, string) bool
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user