beego pro add mysql source type

This commit is contained in:
yitea 2020-07-05 14:54:26 +08:00
parent 8758f6eaa1
commit 9db1e8fb4c
13 changed files with 452 additions and 132 deletions

View File

@ -20,7 +20,7 @@ var DefaultBeegoPro = &Container{
BeegoProFile: system.CurrentDir + "/beegopro.toml", BeegoProFile: system.CurrentDir + "/beegopro.toml",
TimestampFile: system.CurrentDir + "/beegopro.timestamp", TimestampFile: system.CurrentDir + "/beegopro.timestamp",
GoModFile: system.CurrentDir + "/go.mod", GoModFile: system.CurrentDir + "/go.mod",
Option: Option{ UserOption: UserOption{
Debug: false, Debug: false,
ContextDebug: false, ContextDebug: false,
Dsn: "", Dsn: "",
@ -28,7 +28,7 @@ var DefaultBeegoPro = &Container{
ProType: "default", ProType: "default",
ApiPrefix: "/", ApiPrefix: "/",
EnableModule: nil, EnableModule: nil,
Models: make(map[string]ModelContent, 0), Models: make(map[string]TextModel, 0),
GitRemotePath: "https://github.com/beego-dev/beego-pro.git", GitRemotePath: "https://github.com/beego-dev/beego-pro.git",
Branch: "master", Branch: "master",
GitLocalPath: system.BeegoHome + "/beego-pro", GitLocalPath: system.BeegoHome + "/beego-pro",
@ -43,7 +43,7 @@ var DefaultBeegoPro = &Container{
}, },
GenerateTime: time.Now().Format(MDateFormat), GenerateTime: time.Now().Format(MDateFormat),
GenerateTimeUnix: time.Now().Unix(), GenerateTimeUnix: time.Now().Unix(),
Tmpl: Tmpl{}, TmplOption: TmplOption{},
CurPath: system.CurrentDir, CurPath: system.CurrentDir,
EnableModules: make(map[string]interface{}, 0), // get the user configuration, get the enable module result EnableModules: make(map[string]interface{}, 0), // get the user configuration, get the enable module result
FunctionOnce: make(map[string]sync.Once, 0), // get the tmpl configuration, get the function once result FunctionOnce: make(map[string]sync.Once, 0), // get the tmpl configuration, get the function once result
@ -52,13 +52,14 @@ var DefaultBeegoPro = &Container{
func (c *Container) Run() { func (c *Container) Run() {
// init git refresh cache time // init git refresh cache time
c.initTimestamp() c.initTimestamp()
c.initBeegoPro() c.initUserOption()
c.initBeegoTmpl() c.initTemplateOption()
c.initParser()
c.initRender() c.initRender()
c.flushTimestamp() c.flushTimestamp()
} }
func (c *Container) initBeegoPro() { func (c *Container) initUserOption() {
if !utils.IsExist(c.BeegoProFile) { if !utils.IsExist(c.BeegoProFile) {
beeLogger.Log.Fatalf("beego pro config is not exist, beego json path: %s", c.BeegoProFile) beeLogger.Log.Fatalf("beego pro config is not exist, beego json path: %s", c.BeegoProFile)
return return
@ -70,23 +71,23 @@ func (c *Container) initBeegoPro() {
return return
} }
err = viper.Unmarshal(&c.Option) err = viper.Unmarshal(&c.UserOption)
if err != nil { if err != nil {
beeLogger.Log.Fatalf("beego pro config unmarshal error, err: %s", err.Error()) beeLogger.Log.Fatalf("beego pro config unmarshal error, err: %s", err.Error())
return return
} }
if c.Option.Debug { if c.UserOption.Debug {
viper.Debug() viper.Debug()
} }
if c.Option.EnableGomod { if c.UserOption.EnableGomod {
if !utils.IsExist(c.GoModFile) { if !utils.IsExist(c.GoModFile) {
beeLogger.Log.Fatalf("go mod not exist, please create go mod file") beeLogger.Log.Fatalf("go mod not exist, please create go mod file")
return return
} }
} }
for _, value := range c.Option.EnableModule { for _, value := range c.UserOption.EnableModule {
c.EnableModules[value] = struct{}{} c.EnableModules[value] = struct{}{}
} }
@ -94,14 +95,15 @@ func (c *Container) initBeegoPro() {
c.EnableModules["*"] = struct{}{} c.EnableModules["*"] = struct{}{}
} }
if c.Option.Debug { if c.UserOption.Debug {
fmt.Println("c.modules", c.EnableModules) fmt.Println("c.modules", c.EnableModules)
} }
} }
func (c *Container) initBeegoTmpl() { func (c *Container) initTemplateOption() {
if c.Option.EnableGitPull && (c.GenerateTimeUnix-c.Timestamp.GitCacheLastRefresh > c.Option.RefreshGitTime) { if c.UserOption.EnableGitPull && (c.GenerateTimeUnix-c.Timestamp.GitCacheLastRefresh > c.UserOption.RefreshGitTime) {
err := git.CloneORPullRepo(c.Option.GitRemotePath, c.Option.GitLocalPath) err := git.CloneORPullRepo(c.UserOption.GitRemotePath, c.UserOption.GitLocalPath)
if err != nil { if err != nil {
beeLogger.Log.Fatalf("beego pro git clone or pull repo error, err: %s", err) beeLogger.Log.Fatalf("beego pro git clone or pull repo error, err: %s", err)
return return
@ -109,59 +111,49 @@ func (c *Container) initBeegoTmpl() {
c.Timestamp.GitCacheLastRefresh = c.GenerateTimeUnix c.Timestamp.GitCacheLastRefresh = c.GenerateTimeUnix
} }
tree, err := toml.LoadFile(c.Option.GitLocalPath + "/" + c.Option.ProType + "/bee.toml") tree, err := toml.LoadFile(c.UserOption.GitLocalPath + "/" + c.UserOption.ProType + "/bee.toml")
if err != nil { if err != nil {
beeLogger.Log.Fatalf("beego tmpl exec error, err: %s", err) beeLogger.Log.Fatalf("beego tmpl exec error, err: %s", err)
return return
} }
err = tree.Unmarshal(&c.Tmpl) err = tree.Unmarshal(&c.TmplOption)
if err != nil { if err != nil {
beeLogger.Log.Fatalf("beego tmpl parse error, err: %s", err) beeLogger.Log.Fatalf("beego tmpl parse error, err: %s", err)
return return
} }
if c.Option.Debug { if c.UserOption.Debug {
spew.Dump("tmpl", c.Tmpl) spew.Dump("tmpl", c.TmplOption)
} }
for _, value := range c.Tmpl.Descriptor { for _, value := range c.TmplOption.Descriptor {
if value.Once == true { if value.Once == true {
c.FunctionOnce[value.SrcName] = sync.Once{} c.FunctionOnce[value.SrcName] = sync.Once{}
} }
} }
} }
type modelInfo struct { func (c *Container) initParser() {
Module string driver, flag := ParserDriver[c.UserOption.SourceGen]
ModelName string if !flag {
Option Option beeLogger.Log.Fatalf("parse driver not exit, source gen %s", c.UserOption.SourceGen)
Content ModelContent }
Descriptor Descriptor driver.RegisterOption(c.UserOption, c.TmplOption)
TmplPath string c.Parser = driver
GenerateTime string
} }
func (c *Container) initRender() { func (c *Container) initRender() {
for _, desc := range c.Tmpl.Descriptor { for _, desc := range c.TmplOption.Descriptor {
_, allFlag := c.EnableModules["*"] _, allFlag := c.EnableModules["*"]
_, moduleFlag := c.EnableModules[desc.Module] _, moduleFlag := c.EnableModules[desc.Module]
if !allFlag && !moduleFlag { if !allFlag && !moduleFlag {
continue continue
} }
models := c.Parser.GetRenderInfos(desc)
// model table name, model table schema // model table name, model table schema
for modelName, content := range c.Option.Models { for _, m := range models {
m := modelInfo{
Module: desc.Module,
ModelName: modelName,
Content: content,
Option: c.Option,
Descriptor: desc,
TmplPath: c.Tmpl.RenderPath,
GenerateTime: c.GenerateTime,
}
// some render exec once // some render exec once
syncOnce, flag := c.FunctionOnce[desc.SrcName] syncOnce, flag := c.FunctionOnce[desc.SrcName]
if flag { if flag {
@ -175,7 +167,9 @@ func (c *Container) initRender() {
} }
} }
func (c *Container) renderModel(m modelInfo) { func (c *Container) renderModel(m RenderInfo) {
// todo optimize
m.GenerateTime = c.GenerateTime
render := NewRender(m) render := NewRender(m)
render.Exec(m.Descriptor.SrcName) render.Exec(m.Descriptor.SrcName)
if render.Descriptor.IsExistScript() { if render.Descriptor.IsExistScript() {

View File

@ -11,10 +11,10 @@ import (
var SQL utils.DocValue var SQL utils.DocValue
func (c *Container) Migration(args []string) { func (c *Container) Migration(args []string) {
c.initBeegoPro() c.initUserOption()
db, err := sql.Open(c.Option.Driver, c.Option.Dsn) db, err := sql.Open(c.UserOption.Driver, c.UserOption.Dsn)
if err != nil { if err != nil {
beeLogger.Log.Fatalf("Could not connect to '%s' database using '%s': %s", c.Option.Driver, c.Option.Dsn, err) beeLogger.Log.Fatalf("Could not connect to '%s' database using '%s': %s", c.UserOption.Driver, c.UserOption.Dsn, err)
return return
} }

View File

@ -0,0 +1,13 @@
package beegopro
type Parser interface {
RegisterOption(userOption UserOption, tmplOption TmplOption)
Parse(descriptor Descriptor)
GetRenderInfos(descriptor Descriptor) (output []RenderInfo)
Unregister()
}
var ParserDriver = map[string]Parser{
"text": &TextParser{},
"mysql": &MysqlParser{},
}

View File

@ -0,0 +1,88 @@
package beegopro
import (
"database/sql"
"fmt"
"github.com/beego/bee/internal/pkg/utils"
beeLogger "github.com/beego/bee/logger"
)
type MysqlParser struct {
userOption UserOption
tmplOption TmplOption
db *sql.DB
}
func (m *MysqlParser) RegisterOption(userOption UserOption, tmplOption TmplOption) {
m.userOption = userOption
m.tmplOption = tmplOption
}
func (*MysqlParser) Parse(descriptor Descriptor) {
}
func (m *MysqlParser) GetRenderInfos(descriptor Descriptor) (output []RenderInfo) {
tableSchemas, err := m.getTableSchemas()
if err != nil {
beeLogger.Log.Fatalf("get table schemas err %s", err)
}
models := tableSchemas.ToTableMap()
output = make([]RenderInfo, 0)
// model table name, model table schema
for modelName, content := range models {
output = append(output, RenderInfo{
Module: descriptor.Module,
ModelName: modelName,
Content: content,
Option: m.userOption,
Descriptor: descriptor,
TmplPath: m.tmplOption.RenderPath,
})
}
return
}
func (t *MysqlParser) Unregister() {
}
func (m *MysqlParser) getTableSchemas() (resp TableSchemas, err error) {
dsn, err := utils.ParseDSN(m.userOption.Dsn)
if err != nil {
beeLogger.Log.Fatalf("parse dsn err %s", err)
return
}
conn, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/information_schema", dsn.User, dsn.Passwd, dsn.Addr))
if err != nil {
beeLogger.Log.Fatalf("Could not connect to mysql database using '%s': %s", m.userOption.Dsn, err)
return
}
defer conn.Close()
q := `SELECT TABLE_NAME, COLUMN_NAME, IS_NULLABLE, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH,
NUMERIC_PRECISION, NUMERIC_SCALE,COLUMN_TYPE,COLUMN_KEY,COLUMN_COMMENT
FROM COLUMNS WHERE TABLE_SCHEMA = ? ORDER BY TABLE_NAME, ORDINAL_POSITION`
rows, err := conn.Query(q, dsn.DBName)
if err != nil {
return nil, err
}
columns := make(TableSchemas, 0)
for rows.Next() {
cs := TableSchema{}
err := rows.Scan(&cs.TableName, &cs.ColumnName, &cs.IsNullable, &cs.DataType,
&cs.CharacterMaximumLength, &cs.NumericPrecision, &cs.NumericScale,
&cs.ColumnType, &cs.ColumnKey, &cs.Comment)
if err != nil {
return nil, err
}
columns = append(columns, cs)
}
if err := rows.Err(); err != nil {
return nil, err
}
return columns, nil
}

View File

@ -0,0 +1,35 @@
package beegopro
type TextParser struct {
userOption UserOption
tmplOption TmplOption
}
func (t *TextParser) RegisterOption(userOption UserOption, tmplOption TmplOption) {
t.userOption = userOption
t.tmplOption = tmplOption
}
func (*TextParser) Parse(descriptor Descriptor) {
}
func (t *TextParser) GetRenderInfos(descriptor Descriptor) (output []RenderInfo) {
output = make([]RenderInfo, 0)
// model table name, model table schema
for modelName, content := range t.userOption.Models {
output = append(output, RenderInfo{
Module: descriptor.Module,
ModelName: modelName,
Content: content.ToModelInfos(),
Option: t.userOption,
Descriptor: descriptor,
TmplPath: t.tmplOption.RenderPath,
})
}
return
}
func (t *TextParser) Unregister() {
}

View File

@ -15,7 +15,7 @@ type RenderFile struct {
*pongo2render.Render *pongo2render.Render
Context pongo2.Context Context pongo2.Context
GenerateTime string GenerateTime string
Option Option Option UserOption
ModelName string ModelName string
PackageName string PackageName string
FlushFile string FlushFile string
@ -24,7 +24,7 @@ type RenderFile struct {
Descriptor Descriptor Descriptor Descriptor
} }
func NewRender(m modelInfo) *RenderFile { func NewRender(m RenderInfo) *RenderFile {
var ( var (
pathCtx pongo2.Context pathCtx pongo2.Context
newDescriptor Descriptor newDescriptor Descriptor

View File

@ -18,46 +18,41 @@ type Container struct {
BeegoProFile string // beego pro toml BeegoProFile string // beego pro toml
TimestampFile string // store ts file TimestampFile string // store ts file
GoModFile string // go mod file GoModFile string // go mod file
Option Option // user option UserOption UserOption // user option
Tmpl Tmpl // remote tmpl TmplOption TmplOption // tmpl option
CurPath string // user current path CurPath string // user current path
EnableModules map[string]interface{} // beego pro provider a collection of module EnableModules map[string]interface{} // beego pro provider a collection of module
FunctionOnce map[string]sync.Once // exec function once FunctionOnce map[string]sync.Once // exec function once
Timestamp Timestamp Timestamp Timestamp
GenerateTime string GenerateTime string
GenerateTimeUnix int64 GenerateTimeUnix int64
Parser Parser
} }
// user option // user option
type Option struct { type UserOption struct {
Debug bool `json:"debug"` Debug bool `json:"debug"`
ContextDebug bool `json:"contextDebug"` ContextDebug bool `json:"contextDebug"`
Dsn string `json:"dsn"` Dsn string `json:"dsn"`
Driver string `json:"driver"` Driver string `json:"driver"`
ProType string `json:"proType"` ProType string `json:"proType"`
ApiPrefix string `json:"apiPrefix"` ApiPrefix string `json:"apiPrefix"`
EnableModule []string `json:"enableModule"` EnableModule []string `json:"enableModule"`
Models map[string]ModelContent `json:"models"` Models map[string]TextModel `json:"models"`
GitRemotePath string `json:"gitRemotePath"` GitRemotePath string `json:"gitRemotePath"`
Branch string `json:"branch"` Branch string `json:"branch"`
GitLocalPath string `json:"gitLocalPath"` GitLocalPath string `json:"gitLocalPath"`
EnableFormat bool `json:"enableFormat"` EnableFormat bool `json:"enableFormat"`
SourceGen string `json:"sourceGen"` SourceGen string `json:"sourceGen"`
EnableGitPull bool `json:"enbaleGitPull"` EnableGitPull bool `json:"enbaleGitPull"`
Path map[string]string `json:"path"` Path map[string]string `json:"path"`
EnableGomod bool `json:"enableGomod"` EnableGomod bool `json:"enableGomod"`
RefreshGitTime int64 `json:"refreshGitTime"` RefreshGitTime int64 `json:"refreshGitTime"`
Extend map[string]string `json:"extend"` // extend user data Extend map[string]string `json:"extend"` // extend user data
}
type BeegoPro struct {
FilePath string
CurPath string
Option Option
} }
// tmpl option // tmpl option
type Tmpl struct { type TmplOption struct {
RenderPath string `toml:"renderPath"` RenderPath string `toml:"renderPath"`
Descriptor []Descriptor Descriptor []Descriptor
} }
@ -126,7 +121,6 @@ func (d Descriptor) ExecScript(path string) (err error) {
if len(arr) == 0 { if len(arr) == 0 {
return return
} }
fmt.Println("path------>", path)
stdout, stderr, err := command.ExecCmdDir(path, arr[0], arr[1:]...) stdout, stderr, err := command.ExecCmdDir(path, arr[0], arr[1:]...)
if err != nil { if err != nil {
@ -134,7 +128,6 @@ func (d Descriptor) ExecScript(path string) (err error) {
} }
beeLogger.Log.Info(stdout) beeLogger.Log.Info(stdout)
return nil return nil
} }

View File

@ -1,18 +1,11 @@
package beegopro package beegopro
import ( import (
beeLogger "github.com/beego/bee/logger"
"github.com/beego/bee/utils" "github.com/beego/bee/utils"
"strings" "strings"
) )
type ModelContent struct { // parse get the model info
Names []string
Orms []string
Comments []string
Extends []string
}
type ModelInfo struct { type ModelInfo struct {
Name string `json:"name"` // mysql name Name string `json:"name"` // mysql name
InputType string `json:"inputType"` // user input type InputType string `json:"inputType"` // user input type
@ -37,49 +30,12 @@ func (m ModelInfo) IsPrimaryKey() (flag bool) {
return return
} }
func (content ModelContent) ToModelInfoArr() (output []ModelInfo) { type ModelInfos []ModelInfo
namesLen := len(content.Names)
ormsLen := len(content.Orms)
commentsLen := len(content.Comments)
if namesLen != ormsLen && namesLen != commentsLen {
beeLogger.Log.Fatalf("length error, namesLen is %d, ormsLen is %d, commentsLen is %d", namesLen, ormsLen, commentsLen)
}
extendLen := len(content.Extends)
if extendLen != 0 && extendLen != namesLen {
beeLogger.Log.Fatalf("extend length error, namesLen is %d, extendsLen is %d", namesLen, extendLen)
}
output = make([]ModelInfo, 0)
for i, name := range content.Names {
comment := content.Comments[i]
if comment == "" {
comment = name
}
inputType, goType, mysqlType, ormTag := getModelType(content.Orms[i])
m := ModelInfo{
Name: name,
InputType: inputType,
GoType: goType,
Orm: ormTag,
Comment: comment,
MysqlType: mysqlType,
Extend: "",
}
// extend value
if extendLen != 0 {
m.Extend = content.Extends[i]
}
output = append(output, m)
}
return
}
func (content ModelContent) ToModelSchemas() (output ModelSchemas) {
modelInfoArr := content.ToModelInfoArr()
// to render model schemas
func (modelInfos ModelInfos) ToModelSchemas() (output ModelSchemas) {
output = make(ModelSchemas, 0) output = make(ModelSchemas, 0)
for i, value := range modelInfoArr { for i, value := range modelInfos {
if i == 0 && !value.IsPrimaryKey() { if i == 0 && !value.IsPrimaryKey() {
inputType, goType, mysqlType, ormTag := getModelType("auto") inputType, goType, mysqlType, ormTag := getModelType("auto")
output = append(output, &ModelSchema{ output = append(output, &ModelSchema{

View File

@ -0,0 +1,74 @@
package beegopro
import (
"database/sql"
"errors"
"github.com/beego/bee/logger"
)
type TableSchema struct {
TableName string
ColumnName string
IsNullable string
DataType string
CharacterMaximumLength sql.NullInt64
NumericPrecision sql.NullInt64
NumericScale sql.NullInt64
ColumnType string
ColumnKey string
Comment string
}
type TableSchemas []TableSchema
func (tableSchemas TableSchemas) ToTableMap() (resp map[string]ModelInfos) {
resp = make(map[string]ModelInfos)
for _, value := range tableSchemas {
if _, ok := resp[value.TableName]; !ok {
resp[value.TableName] = make(ModelInfos, 0)
}
modelInfos := resp[value.TableName]
inputType, goType, err := value.ToGoType()
if err != nil {
beeLogger.Log.Fatalf("parse go type err %s", err)
return
}
modelInfo := ModelInfo{
Name: value.ColumnName,
InputType: inputType,
GoType: goType,
Comment: value.Comment,
}
if value.ColumnKey == "PRI" {
modelInfo.Orm = "pk"
}
resp[value.TableName] = append(modelInfos, modelInfo)
}
return
}
// GetGoDataType maps an SQL data type to Golang data type
func (col TableSchema) ToGoType() (inputType string, goType string, err error) {
switch col.DataType {
case "char", "varchar", "enum", "set", "text", "longtext", "mediumtext", "tinytext":
goType = "string"
case "blob", "mediumblob", "longblob", "varbinary", "binary":
goType = "[]byte"
case "date", "time", "datetime", "timestamp":
goType, inputType = "time.Time", "dateTime"
case "tinyint", "smallint", "int", "mediumint":
goType = "int"
case "bit", "bigint":
goType = "int64"
case "float", "decimal", "double":
goType = "float64"
}
if goType == "" {
err = errors.New("No compatible datatype (" + col.DataType + ", CamelName: " + col.ColumnName + ") found")
}
return
}

View File

@ -0,0 +1,11 @@
package beegopro
type RenderInfo struct {
Module string
ModelName string
Option UserOption
Content ModelInfos
Descriptor Descriptor
TmplPath string
GenerateTime string
}

View File

@ -0,0 +1,50 @@
package beegopro
import (
beeLogger "github.com/beego/bee/logger"
)
type TextModel struct {
Names []string
Orms []string
Comments []string
Extends []string
}
func (content TextModel) ToModelInfos() (output []ModelInfo) {
namesLen := len(content.Names)
ormsLen := len(content.Orms)
commentsLen := len(content.Comments)
if namesLen != ormsLen && namesLen != commentsLen {
beeLogger.Log.Fatalf("length error, namesLen is %d, ormsLen is %d, commentsLen is %d", namesLen, ormsLen, commentsLen)
}
extendLen := len(content.Extends)
if extendLen != 0 && extendLen != namesLen {
beeLogger.Log.Fatalf("extend length error, namesLen is %d, extendsLen is %d", namesLen, extendLen)
}
output = make([]ModelInfo, 0)
for i, name := range content.Names {
comment := content.Comments[i]
if comment == "" {
comment = name
}
inputType, goType, mysqlType, ormTag := getModelType(content.Orms[i])
m := ModelInfo{
Name: name,
InputType: inputType,
GoType: goType,
Orm: ormTag,
Comment: comment,
MysqlType: mysqlType,
Extend: "",
}
// extend value
if extendLen != 0 {
m.Extend = content.Extends[i]
}
output = append(output, m)
}
return
}

View File

@ -12,7 +12,7 @@ import (
"strings" "strings"
) )
// 获取某项目代码库的标签列表 // git tag
func GetTags(repoPath string, limit int) ([]string, error) { func GetTags(repoPath string, limit int) ([]string, error) {
repo, err := OpenRepository(repoPath) repo, err := OpenRepository(repoPath)
if err != nil { if err != nil {
@ -57,13 +57,6 @@ func CloneORPullRepo(url string, dst string) error {
if !utils.IsDir(dst) { if !utils.IsDir(dst) {
return CloneRepo(url, dst) return CloneRepo(url, dst)
} else { } else {
//projectName, err := getGitProjectName(url)
//if err != nil {
// return err
//}
//fmt.Println("dst------>", dst)
//projectDir := dst + "/" + projectName
utils.Mkdir(dst) utils.Mkdir(dst)
repo, err := OpenRepository(dst) repo, err := OpenRepository(dst)

113
internal/pkg/utils/dsn.go Normal file
View File

@ -0,0 +1,113 @@
package utils
import (
"errors"
"net/url"
"strings"
)
// DSN ...
type DSN struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
}
var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
)
// ParseDSN parses the DSN string to a Config
func ParseDSN(dsn string) (cfg *DSN, err error) {
// New config with some default values
cfg = new(DSN)
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
}
}
cfg.User = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.Addr = dsn[k+1 : i-1]
break
}
}
cfg.Net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.DBName = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
return
}
func parseDSNParams(cfg *DSN, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// lazy init
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
value := param[1]
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
return
}