diff --git a/cmd/commands/beegopro/beegopro.go b/cmd/commands/beegopro/beegopro.go index 0490e3c..978c8f8 100644 --- a/cmd/commands/beegopro/beegopro.go +++ b/cmd/commands/beegopro/beegopro.go @@ -32,6 +32,8 @@ var CmdBeegoPro = &commands.Command{ func init() { CmdBeegoPro.Flag.Var(&beegopro.SQL, "sql", "sql file path") + CmdBeegoPro.Flag.Var(&beegopro.SQLMode, "sqlmode", "sql mode") + CmdBeegoPro.Flag.Var(&beegopro.SQLModePath, "sqlpath", "sql mode path") commands.AvailableCommands = append(commands.AvailableCommands, CmdBeegoPro) } diff --git a/internal/app/module/beegopro/migration.go b/internal/app/module/beegopro/migration.go index 11caa2c..a17a233 100644 --- a/internal/app/module/beegopro/migration.go +++ b/internal/app/module/beegopro/migration.go @@ -6,9 +6,17 @@ import ( "github.com/beego/bee/utils" "io/ioutil" "path/filepath" + "strings" ) var SQL utils.DocValue +var SQLMode utils.DocValue +var SQLModePath utils.DocValue + +var ( + SQLModeUp = "up" + SQLModeDown = "down" +) func (c *Container) Migration(args []string) { c.initUserOption() @@ -17,19 +25,55 @@ func (c *Container) Migration(args []string) { beeLogger.Log.Fatalf("Could not connect to '%s' database using '%s': %s", c.UserOption.Driver, c.UserOption.Dsn, err) return } - defer db.Close() + switch SQLMode.String() { + case SQLModeUp: + doByMode(db, "up.sql") + case SQLModeDown: + doByMode(db, "down.sql") + default: + doBySqlFile(db) + } +} - absFile, _ := filepath.Abs(SQL.String()) - content, err := ioutil.ReadFile(SQL.String()) +func doBySqlFile(db *sql.DB) { + fileName := SQL.String() + if !utils.IsExist(fileName) { + beeLogger.Log.Fatalf("sql mode path not exist, path %s", SQL.String()) + } + doDb(db, fileName) +} + +func doByMode(db *sql.DB, suffix string) { + pathName := SQLModePath.String() + if !utils.IsExist(pathName) { + beeLogger.Log.Fatalf("sql mode path not exist, path %s", SQLModePath.String()) + } + + rd, err := ioutil.ReadDir(pathName) + if err != nil { + beeLogger.Log.Fatalf("read dir err, path %s, err %s", pathName, err) + } + for _, fi := range rd { + if !fi.IsDir() { + if !strings.HasSuffix(fi.Name(), suffix) { + continue + } + doDb(db, filepath.Join(pathName, fi.Name())) + } + } +} + +func doDb(db *sql.DB, filePath string) { + absFile, _ := filepath.Abs(filePath) + content, err := ioutil.ReadFile(filePath) if err != nil { beeLogger.Log.Errorf("read file err %s, abs file %s", err, absFile) } - result, err := db.Exec(string(content)) + _, err = db.Exec(string(content)) if err != nil { beeLogger.Log.Errorf("db exec err %s", err) } - beeLogger.Log.Infof("db exec info %v", result) - + beeLogger.Log.Infof("db exec info %s", filePath) }