From 9db1e8fb4c1c50fa27b8986c72247d07e468859b Mon Sep 17 00:00:00 2001 From: yitea Date: Sun, 5 Jul 2020 14:54:26 +0800 Subject: [PATCH] beego pro add mysql source type --- internal/app/module/beegopro/container.go | 74 ++++++------ internal/app/module/beegopro/migration.go | 6 +- internal/app/module/beegopro/parser.go | 13 ++ internal/app/module/beegopro/parser_mysql.go | 88 ++++++++++++++ internal/app/module/beegopro/parser_text.go | 35 ++++++ internal/app/module/beegopro/render.go | 4 +- internal/app/module/beegopro/schema.go | 53 ++++---- internal/app/module/beegopro/schema_model.go | 54 +-------- .../app/module/beegopro/schema_mysql_model.go | 74 ++++++++++++ internal/app/module/beegopro/schema_render.go | 11 ++ .../app/module/beegopro/schema_text_model.go | 50 ++++++++ internal/pkg/git/repository.go | 9 +- internal/pkg/utils/dsn.go | 113 ++++++++++++++++++ 13 files changed, 452 insertions(+), 132 deletions(-) create mode 100644 internal/app/module/beegopro/parser.go create mode 100644 internal/app/module/beegopro/parser_mysql.go create mode 100644 internal/app/module/beegopro/parser_text.go create mode 100644 internal/app/module/beegopro/schema_mysql_model.go create mode 100644 internal/app/module/beegopro/schema_render.go create mode 100644 internal/app/module/beegopro/schema_text_model.go create mode 100644 internal/pkg/utils/dsn.go diff --git a/internal/app/module/beegopro/container.go b/internal/app/module/beegopro/container.go index 17456ae..2d2cea7 100644 --- a/internal/app/module/beegopro/container.go +++ b/internal/app/module/beegopro/container.go @@ -20,7 +20,7 @@ var DefaultBeegoPro = &Container{ BeegoProFile: system.CurrentDir + "/beegopro.toml", TimestampFile: system.CurrentDir + "/beegopro.timestamp", GoModFile: system.CurrentDir + "/go.mod", - Option: Option{ + UserOption: UserOption{ Debug: false, ContextDebug: false, Dsn: "", @@ -28,7 +28,7 @@ var DefaultBeegoPro = &Container{ ProType: "default", ApiPrefix: "/", EnableModule: nil, - Models: make(map[string]ModelContent, 0), + Models: make(map[string]TextModel, 0), GitRemotePath: "https://github.com/beego-dev/beego-pro.git", Branch: "master", GitLocalPath: system.BeegoHome + "/beego-pro", @@ -43,7 +43,7 @@ var DefaultBeegoPro = &Container{ }, GenerateTime: time.Now().Format(MDateFormat), GenerateTimeUnix: time.Now().Unix(), - Tmpl: Tmpl{}, + TmplOption: TmplOption{}, CurPath: system.CurrentDir, EnableModules: make(map[string]interface{}, 0), // get the user configuration, get the enable module result FunctionOnce: make(map[string]sync.Once, 0), // get the tmpl configuration, get the function once result @@ -52,13 +52,14 @@ var DefaultBeegoPro = &Container{ func (c *Container) Run() { // init git refresh cache time c.initTimestamp() - c.initBeegoPro() - c.initBeegoTmpl() + c.initUserOption() + c.initTemplateOption() + c.initParser() c.initRender() c.flushTimestamp() } -func (c *Container) initBeegoPro() { +func (c *Container) initUserOption() { if !utils.IsExist(c.BeegoProFile) { beeLogger.Log.Fatalf("beego pro config is not exist, beego json path: %s", c.BeegoProFile) return @@ -70,23 +71,23 @@ func (c *Container) initBeegoPro() { return } - err = viper.Unmarshal(&c.Option) + err = viper.Unmarshal(&c.UserOption) if err != nil { beeLogger.Log.Fatalf("beego pro config unmarshal error, err: %s", err.Error()) return } - if c.Option.Debug { + if c.UserOption.Debug { viper.Debug() } - if c.Option.EnableGomod { + if c.UserOption.EnableGomod { if !utils.IsExist(c.GoModFile) { beeLogger.Log.Fatalf("go mod not exist, please create go mod file") return } } - for _, value := range c.Option.EnableModule { + for _, value := range c.UserOption.EnableModule { c.EnableModules[value] = struct{}{} } @@ -94,14 +95,15 @@ func (c *Container) initBeegoPro() { c.EnableModules["*"] = struct{}{} } - if c.Option.Debug { + if c.UserOption.Debug { fmt.Println("c.modules", c.EnableModules) } + } -func (c *Container) initBeegoTmpl() { - if c.Option.EnableGitPull && (c.GenerateTimeUnix-c.Timestamp.GitCacheLastRefresh > c.Option.RefreshGitTime) { - err := git.CloneORPullRepo(c.Option.GitRemotePath, c.Option.GitLocalPath) +func (c *Container) initTemplateOption() { + if c.UserOption.EnableGitPull && (c.GenerateTimeUnix-c.Timestamp.GitCacheLastRefresh > c.UserOption.RefreshGitTime) { + err := git.CloneORPullRepo(c.UserOption.GitRemotePath, c.UserOption.GitLocalPath) if err != nil { beeLogger.Log.Fatalf("beego pro git clone or pull repo error, err: %s", err) return @@ -109,59 +111,49 @@ func (c *Container) initBeegoTmpl() { c.Timestamp.GitCacheLastRefresh = c.GenerateTimeUnix } - tree, err := toml.LoadFile(c.Option.GitLocalPath + "/" + c.Option.ProType + "/bee.toml") + tree, err := toml.LoadFile(c.UserOption.GitLocalPath + "/" + c.UserOption.ProType + "/bee.toml") if err != nil { beeLogger.Log.Fatalf("beego tmpl exec error, err: %s", err) return } - err = tree.Unmarshal(&c.Tmpl) + err = tree.Unmarshal(&c.TmplOption) if err != nil { beeLogger.Log.Fatalf("beego tmpl parse error, err: %s", err) return } - if c.Option.Debug { - spew.Dump("tmpl", c.Tmpl) + if c.UserOption.Debug { + spew.Dump("tmpl", c.TmplOption) } - for _, value := range c.Tmpl.Descriptor { + for _, value := range c.TmplOption.Descriptor { if value.Once == true { c.FunctionOnce[value.SrcName] = sync.Once{} } } } -type modelInfo struct { - Module string - ModelName string - Option Option - Content ModelContent - Descriptor Descriptor - TmplPath string - GenerateTime string +func (c *Container) initParser() { + driver, flag := ParserDriver[c.UserOption.SourceGen] + if !flag { + beeLogger.Log.Fatalf("parse driver not exit, source gen %s", c.UserOption.SourceGen) + } + driver.RegisterOption(c.UserOption, c.TmplOption) + c.Parser = driver } func (c *Container) initRender() { - for _, desc := range c.Tmpl.Descriptor { + for _, desc := range c.TmplOption.Descriptor { _, allFlag := c.EnableModules["*"] _, moduleFlag := c.EnableModules[desc.Module] if !allFlag && !moduleFlag { continue } + models := c.Parser.GetRenderInfos(desc) // model table name, model table schema - for modelName, content := range c.Option.Models { - m := modelInfo{ - Module: desc.Module, - ModelName: modelName, - Content: content, - Option: c.Option, - Descriptor: desc, - TmplPath: c.Tmpl.RenderPath, - GenerateTime: c.GenerateTime, - } - + for _, m := range models { // some render exec once syncOnce, flag := c.FunctionOnce[desc.SrcName] if flag { @@ -175,7 +167,9 @@ func (c *Container) initRender() { } } -func (c *Container) renderModel(m modelInfo) { +func (c *Container) renderModel(m RenderInfo) { + // todo optimize + m.GenerateTime = c.GenerateTime render := NewRender(m) render.Exec(m.Descriptor.SrcName) if render.Descriptor.IsExistScript() { diff --git a/internal/app/module/beegopro/migration.go b/internal/app/module/beegopro/migration.go index 3626efc..11caa2c 100644 --- a/internal/app/module/beegopro/migration.go +++ b/internal/app/module/beegopro/migration.go @@ -11,10 +11,10 @@ import ( var SQL utils.DocValue func (c *Container) Migration(args []string) { - c.initBeegoPro() - db, err := sql.Open(c.Option.Driver, c.Option.Dsn) + c.initUserOption() + db, err := sql.Open(c.UserOption.Driver, c.UserOption.Dsn) if err != nil { - beeLogger.Log.Fatalf("Could not connect to '%s' database using '%s': %s", c.Option.Driver, c.Option.Dsn, err) + beeLogger.Log.Fatalf("Could not connect to '%s' database using '%s': %s", c.UserOption.Driver, c.UserOption.Dsn, err) return } diff --git a/internal/app/module/beegopro/parser.go b/internal/app/module/beegopro/parser.go new file mode 100644 index 0000000..5a8afb2 --- /dev/null +++ b/internal/app/module/beegopro/parser.go @@ -0,0 +1,13 @@ +package beegopro + +type Parser interface { + RegisterOption(userOption UserOption, tmplOption TmplOption) + Parse(descriptor Descriptor) + GetRenderInfos(descriptor Descriptor) (output []RenderInfo) + Unregister() +} + +var ParserDriver = map[string]Parser{ + "text": &TextParser{}, + "mysql": &MysqlParser{}, +} diff --git a/internal/app/module/beegopro/parser_mysql.go b/internal/app/module/beegopro/parser_mysql.go new file mode 100644 index 0000000..204c704 --- /dev/null +++ b/internal/app/module/beegopro/parser_mysql.go @@ -0,0 +1,88 @@ +package beegopro + +import ( + "database/sql" + "fmt" + "github.com/beego/bee/internal/pkg/utils" + beeLogger "github.com/beego/bee/logger" +) + +type MysqlParser struct { + userOption UserOption + tmplOption TmplOption + db *sql.DB +} + +func (m *MysqlParser) RegisterOption(userOption UserOption, tmplOption TmplOption) { + m.userOption = userOption + m.tmplOption = tmplOption + +} + +func (*MysqlParser) Parse(descriptor Descriptor) { + +} + +func (m *MysqlParser) GetRenderInfos(descriptor Descriptor) (output []RenderInfo) { + tableSchemas, err := m.getTableSchemas() + if err != nil { + beeLogger.Log.Fatalf("get table schemas err %s", err) + } + models := tableSchemas.ToTableMap() + + output = make([]RenderInfo, 0) + // model table name, model table schema + for modelName, content := range models { + output = append(output, RenderInfo{ + Module: descriptor.Module, + ModelName: modelName, + Content: content, + Option: m.userOption, + Descriptor: descriptor, + TmplPath: m.tmplOption.RenderPath, + }) + } + return +} + +func (t *MysqlParser) Unregister() { + +} + +func (m *MysqlParser) getTableSchemas() (resp TableSchemas, err error) { + dsn, err := utils.ParseDSN(m.userOption.Dsn) + if err != nil { + beeLogger.Log.Fatalf("parse dsn err %s", err) + return + } + + conn, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/information_schema", dsn.User, dsn.Passwd, dsn.Addr)) + if err != nil { + beeLogger.Log.Fatalf("Could not connect to mysql database using '%s': %s", m.userOption.Dsn, err) + return + } + defer conn.Close() + + q := `SELECT TABLE_NAME, COLUMN_NAME, IS_NULLABLE, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH, +NUMERIC_PRECISION, NUMERIC_SCALE,COLUMN_TYPE,COLUMN_KEY,COLUMN_COMMENT +FROM COLUMNS WHERE TABLE_SCHEMA = ? ORDER BY TABLE_NAME, ORDINAL_POSITION` + rows, err := conn.Query(q, dsn.DBName) + if err != nil { + return nil, err + } + columns := make(TableSchemas, 0) + for rows.Next() { + cs := TableSchema{} + err := rows.Scan(&cs.TableName, &cs.ColumnName, &cs.IsNullable, &cs.DataType, + &cs.CharacterMaximumLength, &cs.NumericPrecision, &cs.NumericScale, + &cs.ColumnType, &cs.ColumnKey, &cs.Comment) + if err != nil { + return nil, err + } + columns = append(columns, cs) + } + if err := rows.Err(); err != nil { + return nil, err + } + return columns, nil +} diff --git a/internal/app/module/beegopro/parser_text.go b/internal/app/module/beegopro/parser_text.go new file mode 100644 index 0000000..dc14601 --- /dev/null +++ b/internal/app/module/beegopro/parser_text.go @@ -0,0 +1,35 @@ +package beegopro + +type TextParser struct { + userOption UserOption + tmplOption TmplOption +} + +func (t *TextParser) RegisterOption(userOption UserOption, tmplOption TmplOption) { + t.userOption = userOption + t.tmplOption = tmplOption +} + +func (*TextParser) Parse(descriptor Descriptor) { + +} + +func (t *TextParser) GetRenderInfos(descriptor Descriptor) (output []RenderInfo) { + output = make([]RenderInfo, 0) + // model table name, model table schema + for modelName, content := range t.userOption.Models { + output = append(output, RenderInfo{ + Module: descriptor.Module, + ModelName: modelName, + Content: content.ToModelInfos(), + Option: t.userOption, + Descriptor: descriptor, + TmplPath: t.tmplOption.RenderPath, + }) + } + return +} + +func (t *TextParser) Unregister() { + +} diff --git a/internal/app/module/beegopro/render.go b/internal/app/module/beegopro/render.go index 4dfa025..ba5d2b7 100644 --- a/internal/app/module/beegopro/render.go +++ b/internal/app/module/beegopro/render.go @@ -15,7 +15,7 @@ type RenderFile struct { *pongo2render.Render Context pongo2.Context GenerateTime string - Option Option + Option UserOption ModelName string PackageName string FlushFile string @@ -24,7 +24,7 @@ type RenderFile struct { Descriptor Descriptor } -func NewRender(m modelInfo) *RenderFile { +func NewRender(m RenderInfo) *RenderFile { var ( pathCtx pongo2.Context newDescriptor Descriptor diff --git a/internal/app/module/beegopro/schema.go b/internal/app/module/beegopro/schema.go index ceebfe8..f29d981 100644 --- a/internal/app/module/beegopro/schema.go +++ b/internal/app/module/beegopro/schema.go @@ -18,46 +18,41 @@ type Container struct { BeegoProFile string // beego pro toml TimestampFile string // store ts file GoModFile string // go mod file - Option Option // user option - Tmpl Tmpl // remote tmpl + UserOption UserOption // user option + TmplOption TmplOption // tmpl option CurPath string // user current path EnableModules map[string]interface{} // beego pro provider a collection of module FunctionOnce map[string]sync.Once // exec function once Timestamp Timestamp GenerateTime string GenerateTimeUnix int64 + Parser Parser } // user option -type Option struct { - Debug bool `json:"debug"` - ContextDebug bool `json:"contextDebug"` - Dsn string `json:"dsn"` - Driver string `json:"driver"` - ProType string `json:"proType"` - ApiPrefix string `json:"apiPrefix"` - EnableModule []string `json:"enableModule"` - Models map[string]ModelContent `json:"models"` - GitRemotePath string `json:"gitRemotePath"` - Branch string `json:"branch"` - GitLocalPath string `json:"gitLocalPath"` - EnableFormat bool `json:"enableFormat"` - SourceGen string `json:"sourceGen"` - EnableGitPull bool `json:"enbaleGitPull"` - Path map[string]string `json:"path"` - EnableGomod bool `json:"enableGomod"` - RefreshGitTime int64 `json:"refreshGitTime"` - Extend map[string]string `json:"extend"` // extend user data -} - -type BeegoPro struct { - FilePath string - CurPath string - Option Option +type UserOption struct { + Debug bool `json:"debug"` + ContextDebug bool `json:"contextDebug"` + Dsn string `json:"dsn"` + Driver string `json:"driver"` + ProType string `json:"proType"` + ApiPrefix string `json:"apiPrefix"` + EnableModule []string `json:"enableModule"` + Models map[string]TextModel `json:"models"` + GitRemotePath string `json:"gitRemotePath"` + Branch string `json:"branch"` + GitLocalPath string `json:"gitLocalPath"` + EnableFormat bool `json:"enableFormat"` + SourceGen string `json:"sourceGen"` + EnableGitPull bool `json:"enbaleGitPull"` + Path map[string]string `json:"path"` + EnableGomod bool `json:"enableGomod"` + RefreshGitTime int64 `json:"refreshGitTime"` + Extend map[string]string `json:"extend"` // extend user data } // tmpl option -type Tmpl struct { +type TmplOption struct { RenderPath string `toml:"renderPath"` Descriptor []Descriptor } @@ -126,7 +121,6 @@ func (d Descriptor) ExecScript(path string) (err error) { if len(arr) == 0 { return } - fmt.Println("path------>", path) stdout, stderr, err := command.ExecCmdDir(path, arr[0], arr[1:]...) if err != nil { @@ -134,7 +128,6 @@ func (d Descriptor) ExecScript(path string) (err error) { } beeLogger.Log.Info(stdout) - return nil } diff --git a/internal/app/module/beegopro/schema_model.go b/internal/app/module/beegopro/schema_model.go index 1cb5858..25e2c8c 100644 --- a/internal/app/module/beegopro/schema_model.go +++ b/internal/app/module/beegopro/schema_model.go @@ -1,18 +1,11 @@ package beegopro import ( - beeLogger "github.com/beego/bee/logger" "github.com/beego/bee/utils" "strings" ) -type ModelContent struct { - Names []string - Orms []string - Comments []string - Extends []string -} - +// parse get the model info type ModelInfo struct { Name string `json:"name"` // mysql name InputType string `json:"inputType"` // user input type @@ -37,49 +30,12 @@ func (m ModelInfo) IsPrimaryKey() (flag bool) { return } -func (content ModelContent) ToModelInfoArr() (output []ModelInfo) { - namesLen := len(content.Names) - ormsLen := len(content.Orms) - commentsLen := len(content.Comments) - if namesLen != ormsLen && namesLen != commentsLen { - beeLogger.Log.Fatalf("length error, namesLen is %d, ormsLen is %d, commentsLen is %d", namesLen, ormsLen, commentsLen) - } - extendLen := len(content.Extends) - if extendLen != 0 && extendLen != namesLen { - beeLogger.Log.Fatalf("extend length error, namesLen is %d, extendsLen is %d", namesLen, extendLen) - } - - output = make([]ModelInfo, 0) - for i, name := range content.Names { - comment := content.Comments[i] - if comment == "" { - comment = name - } - inputType, goType, mysqlType, ormTag := getModelType(content.Orms[i]) - - m := ModelInfo{ - Name: name, - InputType: inputType, - GoType: goType, - Orm: ormTag, - Comment: comment, - MysqlType: mysqlType, - Extend: "", - } - // extend value - if extendLen != 0 { - m.Extend = content.Extends[i] - } - output = append(output, m) - } - return -} - -func (content ModelContent) ToModelSchemas() (output ModelSchemas) { - modelInfoArr := content.ToModelInfoArr() +type ModelInfos []ModelInfo +// to render model schemas +func (modelInfos ModelInfos) ToModelSchemas() (output ModelSchemas) { output = make(ModelSchemas, 0) - for i, value := range modelInfoArr { + for i, value := range modelInfos { if i == 0 && !value.IsPrimaryKey() { inputType, goType, mysqlType, ormTag := getModelType("auto") output = append(output, &ModelSchema{ diff --git a/internal/app/module/beegopro/schema_mysql_model.go b/internal/app/module/beegopro/schema_mysql_model.go new file mode 100644 index 0000000..177aa0e --- /dev/null +++ b/internal/app/module/beegopro/schema_mysql_model.go @@ -0,0 +1,74 @@ +package beegopro + +import ( + "database/sql" + "errors" + "github.com/beego/bee/logger" +) + +type TableSchema struct { + TableName string + ColumnName string + IsNullable string + DataType string + CharacterMaximumLength sql.NullInt64 + NumericPrecision sql.NullInt64 + NumericScale sql.NullInt64 + ColumnType string + ColumnKey string + Comment string +} + +type TableSchemas []TableSchema + +func (tableSchemas TableSchemas) ToTableMap() (resp map[string]ModelInfos) { + + resp = make(map[string]ModelInfos) + for _, value := range tableSchemas { + if _, ok := resp[value.TableName]; !ok { + resp[value.TableName] = make(ModelInfos, 0) + } + + modelInfos := resp[value.TableName] + inputType, goType, err := value.ToGoType() + if err != nil { + beeLogger.Log.Fatalf("parse go type err %s", err) + return + } + + modelInfo := ModelInfo{ + Name: value.ColumnName, + InputType: inputType, + GoType: goType, + Comment: value.Comment, + } + + if value.ColumnKey == "PRI" { + modelInfo.Orm = "pk" + } + resp[value.TableName] = append(modelInfos, modelInfo) + } + return +} + +// GetGoDataType maps an SQL data type to Golang data type +func (col TableSchema) ToGoType() (inputType string, goType string, err error) { + switch col.DataType { + case "char", "varchar", "enum", "set", "text", "longtext", "mediumtext", "tinytext": + goType = "string" + case "blob", "mediumblob", "longblob", "varbinary", "binary": + goType = "[]byte" + case "date", "time", "datetime", "timestamp": + goType, inputType = "time.Time", "dateTime" + case "tinyint", "smallint", "int", "mediumint": + goType = "int" + case "bit", "bigint": + goType = "int64" + case "float", "decimal", "double": + goType = "float64" + } + if goType == "" { + err = errors.New("No compatible datatype (" + col.DataType + ", CamelName: " + col.ColumnName + ") found") + } + return +} diff --git a/internal/app/module/beegopro/schema_render.go b/internal/app/module/beegopro/schema_render.go new file mode 100644 index 0000000..e7bb531 --- /dev/null +++ b/internal/app/module/beegopro/schema_render.go @@ -0,0 +1,11 @@ +package beegopro + +type RenderInfo struct { + Module string + ModelName string + Option UserOption + Content ModelInfos + Descriptor Descriptor + TmplPath string + GenerateTime string +} diff --git a/internal/app/module/beegopro/schema_text_model.go b/internal/app/module/beegopro/schema_text_model.go new file mode 100644 index 0000000..3483d8f --- /dev/null +++ b/internal/app/module/beegopro/schema_text_model.go @@ -0,0 +1,50 @@ +package beegopro + +import ( + beeLogger "github.com/beego/bee/logger" +) + +type TextModel struct { + Names []string + Orms []string + Comments []string + Extends []string +} + +func (content TextModel) ToModelInfos() (output []ModelInfo) { + namesLen := len(content.Names) + ormsLen := len(content.Orms) + commentsLen := len(content.Comments) + if namesLen != ormsLen && namesLen != commentsLen { + beeLogger.Log.Fatalf("length error, namesLen is %d, ormsLen is %d, commentsLen is %d", namesLen, ormsLen, commentsLen) + } + extendLen := len(content.Extends) + if extendLen != 0 && extendLen != namesLen { + beeLogger.Log.Fatalf("extend length error, namesLen is %d, extendsLen is %d", namesLen, extendLen) + } + + output = make([]ModelInfo, 0) + for i, name := range content.Names { + comment := content.Comments[i] + if comment == "" { + comment = name + } + inputType, goType, mysqlType, ormTag := getModelType(content.Orms[i]) + + m := ModelInfo{ + Name: name, + InputType: inputType, + GoType: goType, + Orm: ormTag, + Comment: comment, + MysqlType: mysqlType, + Extend: "", + } + // extend value + if extendLen != 0 { + m.Extend = content.Extends[i] + } + output = append(output, m) + } + return +} diff --git a/internal/pkg/git/repository.go b/internal/pkg/git/repository.go index c2d1154..d32d597 100644 --- a/internal/pkg/git/repository.go +++ b/internal/pkg/git/repository.go @@ -12,7 +12,7 @@ import ( "strings" ) -// 获取某项目代码库的标签列表 +// git tag func GetTags(repoPath string, limit int) ([]string, error) { repo, err := OpenRepository(repoPath) if err != nil { @@ -57,13 +57,6 @@ func CloneORPullRepo(url string, dst string) error { if !utils.IsDir(dst) { return CloneRepo(url, dst) } else { - //projectName, err := getGitProjectName(url) - //if err != nil { - // return err - //} - - //fmt.Println("dst------>", dst) - //projectDir := dst + "/" + projectName utils.Mkdir(dst) repo, err := OpenRepository(dst) diff --git a/internal/pkg/utils/dsn.go b/internal/pkg/utils/dsn.go new file mode 100644 index 0000000..9948be4 --- /dev/null +++ b/internal/pkg/utils/dsn.go @@ -0,0 +1,113 @@ +package utils + +import ( + "errors" + "net/url" + "strings" +) + +// DSN ... +type DSN struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters +} + +var ( + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") +) + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *DSN, err error) { + // New config with some default values + cfg = new(DSN) + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + return +} + +func parseDSNParams(cfg *DSN, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + value := param[1] + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + return +}