diff --git a/.travis.yml b/.travis.yml index 3c821dcd..93536488 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - tip - - 1.6.0 + - 1.6 - 1.5.3 - 1.4.3 services: @@ -31,21 +30,20 @@ install: - go get github.com/belogik/goes - go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/ledis - - go get golang.org/x/tools/cmd/vet - - go get github.com/golang/lint/golint - go get github.com/ssdb/gossdb/ssdb before_script: + - psql --version - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" + - sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi" + - sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi" - mkdir -p res/var - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d after_script: -killall -w ssdb-server - rm -rf ./res/var/* script: - - go vet -x ./... - - $HOME/gopath/bin/golint ./... - go test -v ./... -notifications: - webhooks: https://hooks.pubu.im/services/z7m9bvybl3rgtg9 +addons: + postgresql: "9.4" diff --git a/README.md b/README.md index 6c589584..fbd7ccb7 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ func main(){ ``` ######Congratulations! You just built your first beego app. -Open your browser and visit `http://localhost:8000`. +Open your browser and visit `http://localhost:8080`. Please see [Documentation](http://beego.me/docs) for more. ## Features diff --git a/admin.go b/admin.go index 031e6421..a2b2f53a 100644 --- a/admin.go +++ b/admin.go @@ -23,7 +23,10 @@ import ( "text/template" "time" + "reflect" + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" ) @@ -90,57 +93,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) { switch command { case "conf": m := make(map[string]interface{}) + list("BConfig", BConfig, m) m["AppConfigPath"] = appConfigPath m["AppConfigProvider"] = appConfigProvider - m["BConfig.AppName"] = BConfig.AppName - m["BConfig.RunMode"] = BConfig.RunMode - m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive - m["BConfig.ServerName"] = BConfig.ServerName - m["BConfig.RecoverPanic"] = BConfig.RecoverPanic - m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody - m["BConfig.EnableGzip"] = BConfig.EnableGzip - m["BConfig.MaxMemory"] = BConfig.MaxMemory - m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow - m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful - m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut - m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 - m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP - m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr - m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort - m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS - m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr - m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort - m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile - m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile - m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin - m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr - m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort - m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi - m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo - m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender - m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs - m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName - m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator - m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex - m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir - m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip - m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft - m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight - m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath - m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF - m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey - m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire - m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn - m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider - m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName - m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime - m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig - m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime - m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie - m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain - m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs - m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum - m["BConfig.Log.Outputs"] = BConfig.Log.Outputs tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) tmpl = template.Must(tmpl.Parse(configTpl)) tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) @@ -196,7 +151,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { BeforeExec: "Before Exec", AfterExec: "After Exec", FinishRouter: "Finish Router"} { - if bf, ok := BeeApp.Handlers.filters[k]; ok { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { filterType = fr filterTypes = append(filterTypes, filterType) resultList := new([][]string) @@ -223,6 +178,28 @@ func listConf(rw http.ResponseWriter, r *http.Request) { } } +func list(root string, p interface{}, m map[string]interface{}) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + func printTree(resultList *[][]string, t *Tree) { for _, tr := range t.fixrouters { printTree(resultList, tr) @@ -410,7 +387,7 @@ func (admin *adminApp) Run() { for p, f := range admin.routers { http.Handle(p, f) } - BeeLogger.Info("Admin server Running on %s", addr) + logs.Info("Admin server Running on %s", addr) var err error if BConfig.Listen.Graceful { @@ -419,6 +396,6 @@ func (admin *adminApp) Run() { err = http.ListenAndServe(addr, nil) } if err != nil { - BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) } } diff --git a/admin_test.go b/admin_test.go new file mode 100644 index 00000000..04744f8c --- /dev/null +++ b/admin_test.go @@ -0,0 +1,72 @@ +package beego + +import ( + "testing" + "fmt" +) + +func TestList_01(t *testing.T) { + m := make(map[string]interface{}) + list("BConfig", BConfig, m) + t.Log(m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k])!= fmt.Sprint(v) { + t.Log(k, "old-key",v,"new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() map[string]interface{} { + m := make(map[string]interface{}) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} diff --git a/app.go b/app.go index af54ea4b..423a0a6b 100644 --- a/app.go +++ b/app.go @@ -24,6 +24,7 @@ import ( "time" "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -68,9 +69,9 @@ func (app *App) Run() { if BConfig.Listen.EnableFcgi { if BConfig.Listen.EnableStdIo { if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O - BeeLogger.Info("Use FCGI via standard I/O") + logs.Info("Use FCGI via standard I/O") } else { - BeeLogger.Critical("Cannot use FCGI via standard I/O", err) + logs.Critical("Cannot use FCGI via standard I/O", err) } return } @@ -84,10 +85,10 @@ func (app *App) Run() { l, err = net.Listen("tcp", addr) } if err != nil { - BeeLogger.Critical("Listen: ", err) + logs.Critical("Listen: ", err) } if err = fcgi.Serve(l, app.Handlers); err != nil { - BeeLogger.Critical("fcgi.Serve: ", err) + logs.Critical("fcgi.Serve: ", err) } return } @@ -95,6 +96,7 @@ func (app *App) Run() { app.Server.Handler = app.Handlers app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") // run graceful mode if BConfig.Listen.Graceful { @@ -111,7 +113,7 @@ func (app *App) Run() { server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -126,7 +128,7 @@ func (app *App) Run() { server.Network = "tcp4" } if err := server.ListenAndServe(); err != nil { - BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -137,16 +139,18 @@ func (app *App) Run() { } // run normal mode - app.Server.Addr = addr if BConfig.Listen.EnableHTTPS { go func() { time.Sleep(20 * time.Microsecond) if BConfig.Listen.HTTPSPort != 0 { app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } else if BConfig.Listen.EnableHTTP { + BeeLogger.Info("Start https server error, confict with http.Please reset https port") + return } - BeeLogger.Info("https server Running on %s", app.Server.Addr) + logs.Info("https server Running on %s", app.Server.Addr) if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err) + logs.Critical("ListenAndServeTLS: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -155,24 +159,24 @@ func (app *App) Run() { if BConfig.Listen.EnableHTTP { go func() { app.Server.Addr = addr - BeeLogger.Info("http server Running on %s", app.Server.Addr) + logs.Info("http server Running on %s", app.Server.Addr) if BConfig.Listen.ListenTCP4 { ln, err := net.Listen("tcp4", app.Server.Addr) if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return } if err = app.Server.Serve(ln); err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return } } else { if err := app.Server.ListenAndServe(); err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true } diff --git a/beego.go b/beego.go index 8f82cdcf..32b64f75 100644 --- a/beego.go +++ b/beego.go @@ -51,6 +51,7 @@ func AddAPPStartHook(hf hookfunc) { // beego.Run(":8089") // beego.Run("127.0.0.1:8089") func Run(params ...string) { + initBeforeHTTPRun() if len(params) > 0 && params[0] != "" { @@ -71,9 +72,9 @@ func initBeforeHTTPRun() { AddAPPStartHook(registerMime) AddAPPStartHook(registerDefaultErrorHandler) AddAPPStartHook(registerSession) - AddAPPStartHook(registerDocs) AddAPPStartHook(registerTemplate) AddAPPStartHook(registerAdmin) + AddAPPStartHook(registerGzip) for _, hk := range hooks { if err := hk(); err != nil { @@ -84,8 +85,11 @@ func initBeforeHTTPRun() { // TestBeegoInit is for test package init func TestBeegoInit(ap string) { - os.Setenv("BEEGO_RUNMODE", "test") appConfigPath = filepath.Join(ap, "conf", "app.conf") os.Chdir(ap) + if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { + panic(err) + } + BConfig.RunMode = "test" initBeforeHTTPRun() } diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 47c5acc6..6b81da4d 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -18,9 +18,8 @@ import ( "testing" "time" - "github.com/garyburd/redigo/redis" - "github.com/astaxie/beego/cache" + "github.com/garyburd/redigo/redis" ) func TestRedisCache(t *testing.T) { diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go index e03ba343..dd474960 100644 --- a/cache/ssdb/ssdb_test.go +++ b/cache/ssdb/ssdb_test.go @@ -1,10 +1,11 @@ package ssdb import ( - "github.com/astaxie/beego/cache" "strconv" "testing" "time" + + "github.com/astaxie/beego/cache" ) func TestSsdbcacheCache(t *testing.T) { diff --git a/config.go b/config.go index 067f0d1b..effc5e12 100644 --- a/config.go +++ b/config.go @@ -18,9 +18,11 @@ import ( "fmt" "os" "path/filepath" + "reflect" "strings" "github.com/astaxie/beego/config" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/session" "github.com/astaxie/beego/utils" ) @@ -81,14 +83,17 @@ type WebConfig struct { // SessionConfig holds session related config type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHttpHeader string + EnableSidInUrlQuery bool // enable get the sessionId from Url Query params } // LogConfig holds Log related config @@ -115,11 +120,30 @@ var ( ) func init() { - AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + workPath, err := os.Getwd() + if err != nil { + panic(err) + } + appConfigPath = filepath.Join(workPath, "conf", "app.conf") + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", "app.conf") + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} - os.Chdir(AppPath) - - BConfig = &Config{ +func newBConfig() *Config { + return &Config{ AppName: "beego", RunMode: DEV, RouterCaseSensitive: true, @@ -162,14 +186,17 @@ func init() { XSRFKey: "beegoxsrf", XSRFExpire: 0, Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHttpHeader: "Beegosessionid", + EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params }, }, Log: LogConfig{ @@ -178,16 +205,6 @@ func init() { Outputs: map[string]string{"console": ""}, }, } - - appConfigPath = filepath.Join(AppPath, "conf", "app.conf") - if !utils.FileExists(appConfigPath) { - AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} - return - } - - if err := parseConfig(appConfigPath); err != nil { - panic(err) - } } // now only support ini, next will support json. @@ -196,63 +213,23 @@ func parseConfig(appConfigPath string) (err error) { if err != nil { return err } + return assignConfig(AppConfig) +} + +func assignConfig(ac config.Configer) error { // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode - } else if runMode := AppConfig.String("RunMode"); runMode != "" { + } else if runMode := ac.String("RunMode"); runMode != "" { BConfig.RunMode = runMode } - BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName) - BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic) - BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive) - BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName) - BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip) - BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow) - BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody) - BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory) - BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful) - BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr") - BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort) - BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4) - BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP) - BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS) - BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr) - BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort) - BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile) - BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile) - BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin) - BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr) - BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort) - BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi) - BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo) - BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut) - BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender) - BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath) - BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex) - BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName) - BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator) - BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs) - BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey) - BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF) - BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire) - BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft) - BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight) - BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn) - BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider) - BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName) - BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig) - BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime) - BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime) - BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie) - BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain) - BConfig.Log.AccessLogs = AppConfig.DefaultBool("LogAccessLogs", BConfig.Log.AccessLogs) - BConfig.Log.FileLineNum = AppConfig.DefaultBool("LogFileLineNum", BConfig.Log.FileLineNum) + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } - if sd := AppConfig.String("StaticDir"); sd != "" { - for k := range BConfig.WebConfig.StaticDir { - delete(BConfig.WebConfig.StaticDir, k) - } + if sd := ac.String("StaticDir"); sd != "" { + BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) for _, v := range sds { if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { @@ -262,7 +239,8 @@ func parseConfig(appConfigPath string) (err error) { } } } - if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { + + if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { extensions := strings.Split(sgz, ",") fileExts := []string{} for _, ext := range extensions { @@ -280,7 +258,7 @@ func parseConfig(appConfigPath string) (err error) { } } - if lo := AppConfig.String("LogOutputs"); lo != "" { + if lo := ac.String("LogOutputs"); lo != "" { los := strings.Split(lo, ";") for _, v := range los { if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { @@ -292,18 +270,50 @@ func parseConfig(appConfigPath string) (err error) { } //init log - BeeLogger.Reset() + logs.Reset() for adaptor, config := range BConfig.Log.Outputs { - err = BeeLogger.SetLogger(adaptor, config) + err := logs.SetLogger(adaptor, config) if err != nil { - fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err) + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) } } - SetLogFuncCall(BConfig.Log.FileLineNum) + logs.SetLogFuncCall(BConfig.Log.FileLineNum) return nil } +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(int64(ac.DefaultInt64(name, pf.Int()))) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + //do nothing here + } + } + +} + // LoadAppConfig allow developer to apply a config file func LoadAppConfig(adapterName, configPath string) error { absConfigPath, err := filepath.Abs(configPath) @@ -315,10 +325,6 @@ func LoadAppConfig(adapterName, configPath string) error { return fmt.Errorf("the target config file: %s don't exist", configPath) } - if absConfigPath == appConfigPath { - return nil - } - appConfigPath = absConfigPath appConfigProvider = adapterName @@ -352,7 +358,7 @@ func (b *beegoAppConfig) String(key string) string { } func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { return v } return b.innerConfig.Strings(key) diff --git a/config/config.go b/config/config.go index c0afec05..9f41fb79 100644 --- a/config/config.go +++ b/config/config.go @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package config is used to parse config +// Package config is used to parse config. // Usage: -// import( -// "github.com/astaxie/beego/config" -// ) +// import "github.com/astaxie/beego/config" +//Examples. // // cnf, err := config.NewConfig("ini", "config.conf") // @@ -38,12 +37,12 @@ // cnf.DIY(key string) (interface{}, error) // cnf.GetSection(section string) (map[string]string, error) // cnf.SaveConfigFile(filename string) error -// -// more docs http://beego.me/docs/module/config.md +//More docs http://beego.me/docs/module/config.md package config import ( "fmt" + "os" ) // Configer defines how to get and set value from configuration raw data. @@ -107,6 +106,69 @@ func NewConfigData(adapterName string, data []byte) (Configer, error) { return adapter.ParseData(data) } +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + switch value := v.(type) { + case string: + m[k] = ExpandValueEnv(value) + case map[string]interface{}: + m[k] = ExpandValueEnvForMap(value) + case map[string]string: + for k2, v2 := range value { + value[k2] = ExpandValueEnv(v2) + } + m[k] = value + } + } + return m +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) (realValue string) { + realValue = value + + vLen := len(value) + // 3 = ${} + if vLen < 3 { + return + } + // Need start with "${" and end with "}", then return. + if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { + return + } + + key := "" + defalutV := "" + // value start with "${" + for i := 2; i < vLen; i++ { + if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { + key = value[2:i] + defalutV = value[i+2 : vLen-1] // other string is default value. + break + } else if value[i] == '}' { + key = value[2:i] + break + } + } + + realValue = os.Getenv(key) + if realValue == "" { + realValue = defalutV + } + + return +} + // ParseBool returns the boolean value represented by the string. // // It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/config/fake.go b/config/fake.go index 7e362608..f5144598 100644 --- a/config/fake.go +++ b/config/fake.go @@ -38,7 +38,7 @@ func (c *fakeConfigContainer) String(key string) string { } func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.getData(key) + v := c.String(key) if v == "" { return defaultval } @@ -46,7 +46,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin } func (c *fakeConfigContainer) Strings(key string) []string { - v := c.getData(key) + v := c.String(key) if v == "" { return nil } diff --git a/config/ini.go b/config/ini.go index 9c19b9b1..53bd992d 100644 --- a/config/ini.go +++ b/config/ini.go @@ -82,6 +82,10 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { if err == io.EOF { break } + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } if bytes.Equal(line, bEmpty) { continue } @@ -162,7 +166,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { val = bytes.Trim(val, `"`) } - cfg.data[section][key] = string(val) + cfg.data[section][key] = ExpandValueEnv(string(val)) if comment.Len() > 0 { cfg.keyComment[section+"."+key] = comment.String() comment.Reset() @@ -296,7 +300,9 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro return nil, errors.New("not exist setction") } -// SaveConfigFile save the config into file +// SaveConfigFile save the config into file. +// +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation. func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) diff --git a/config/ini_test.go b/config/ini_test.go index 93fce61f..83ff3668 100644 --- a/config/ini_test.go +++ b/config/ini_test.go @@ -42,11 +42,14 @@ needlogin = ON enableSession = Y enableCookie = N flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} [demo] key1="asta" key2 = "xie" CaseInsensitive = true peers = one;two;three +password = ${GOPATH} ` keyValue = map[string]interface{}{ @@ -64,10 +67,13 @@ peers = one;two;three "enableSession": true, "enableCookie": false, "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), "demo::key1": "asta", "demo::key2": "xie", "demo::CaseInsensitive": true, "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), "null": "", "demo2::key1": "", "error": "", diff --git a/config/json.go b/config/json.go index fce517eb..a0d93210 100644 --- a/config/json.go +++ b/config/json.go @@ -57,6 +57,9 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) { } x.data["rootArray"] = wrappingArray } + + x.data = ExpandValueEnvForMap(x.data) + return x, nil } diff --git a/config/json_test.go b/config/json_test.go index df663461..24ff9644 100644 --- a/config/json_test.go +++ b/config/json_test.go @@ -86,16 +86,19 @@ func TestJson(t *testing.T) { "enableSession": "Y", "enableCookie": "N", "flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", "database": { "host": "host", "port": "port", "database": "database", "username": "username", - "password": "password", + "password": "${GOPATH}", "conns":{ "maxconnection":12, "autoconnect":true, - "connectioninfo":"info" + "connectioninfo":"info", + "root": "${GOPATH}" } } }` @@ -115,13 +118,16 @@ func TestJson(t *testing.T) { "enableSession": true, "enableCookie": false, "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), "database::host": "host", "database::port": "port", "database::database": "database", - "database::password": "password", + "database::password": os.Getenv("GOPATH"), "database::conns::maxconnection": 12, "database::conns::autoconnect": true, "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), "unknown": "", } ) diff --git a/config/xml/xml.go b/config/xml/xml.go index b5291bf4..0c4e4d27 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package xml for config provider +// Package xml for config provider. // -// depend on github.com/beego/x2j +// depend on github.com/beego/x2j. // -// go install github.com/beego/x2j +// go install github.com/beego/x2j. // // Usage: -// import( -// _ "github.com/astaxie/beego/config/xml" -// "github.com/astaxie/beego/config" -// ) +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) // // cnf, err := config.NewConfig("xml", "config.xml") // -// more docs http://beego.me/docs/module/config.md +//More docs http://beego.me/docs/module/config.md package xml import ( @@ -69,7 +69,7 @@ func (xc *Config) Parse(filename string) (config.Configer, error) { return nil, err } - x.data = d["config"].(map[string]interface{}) + x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) return x, nil } @@ -92,7 +92,7 @@ type ConfigContainer struct { // Bool returns the boolean value for a given key. func (c *ConfigContainer) Bool(key string) (bool, error) { - if v, ok := c.data[key]; ok { + if v := c.data[key]; v != nil { return config.ParseBool(v) } return false, fmt.Errorf("not exist key: %q", key) diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go index 60dcba54..d8a09a59 100644 --- a/config/xml/xml_test.go +++ b/config/xml/xml_test.go @@ -15,14 +15,18 @@ package xml import ( + "fmt" "os" "testing" "github.com/astaxie/beego/config" ) -//xml parse should incluce in tags -var xmlcontext = ` +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` beeapi 8080 @@ -31,10 +35,25 @@ var xmlcontext = ` dev false true +${GOPATH} +${GOPATH||/home/go} ` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) -func TestXML(t *testing.T) { f, err := os.Create("testxml.conf") if err != nil { t.Fatal(err) @@ -50,39 +69,42 @@ func TestXML(t *testing.T) { if err != nil { t.Fatal(err) } - if xmlconf.String("appname") != "beeapi" { - t.Fatal("appname not equal to beeapi") - } - if port, err := xmlconf.Int("httpport"); err != nil || port != 8080 { - t.Error(port) - t.Fatal(err) - } - if port, err := xmlconf.Int64("mysqlport"); err != nil || port != 3600 { - t.Error(port) - t.Fatal(err) - } - if pi, err := xmlconf.Float("PI"); err != nil || pi != 3.1415976 { - t.Error(pi) - t.Fatal(err) - } - if xmlconf.String("runmode") != "dev" { - t.Fatal("runmode not equal to dev") - } - if v, err := xmlconf.Bool("autorender"); err != nil || v != false { - t.Error(v) - t.Fatal(err) - } - if v, err := xmlconf.Bool("copyrequestbody"); err != nil || v != true { - t.Error(v) - t.Fatal(err) + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + } + if err = xmlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } if xmlconf.String("name") != "astaxie" { t.Fatal("get name error") } - if xmlconf.Strings("emptystrings") != nil { - t.Fatal("get emtpy strings error") - } } diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index 7e1d0426..64e25cb3 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -19,14 +19,14 @@ // go install github.com/beego/goyaml2 // // Usage: -// import( +// import( // _ "github.com/astaxie/beego/config/yaml" -// "github.com/astaxie/beego/config" -// ) +// "github.com/astaxie/beego/config" +// ) // // cnf, err := config.NewConfig("yaml", "config.yaml") // -// more docs http://beego.me/docs/module/config.md +//More docs http://beego.me/docs/module/config.md package yaml import ( @@ -110,6 +110,7 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { log.Println("Not a Map? >> ", string(buf), data) cnf = nil } + cnf = config.ExpandValueEnvForMap(cnf) return } @@ -121,10 +122,11 @@ type ConfigContainer struct { // Bool returns the boolean value for a given key. func (c *ConfigContainer) Bool(key string) (bool, error) { - if v, ok := c.data[key]; ok { - return config.ParseBool(v) + v, err := c.getData(key) + if err != nil { + return false, err } - return false, fmt.Errorf("not exist key: %q", key) + return config.ParseBool(v) } // DefaultBool return the bool value if has no error @@ -139,8 +141,12 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { // Int returns the integer value for a given key. func (c *ConfigContainer) Int(key string) (int, error) { - if v, ok := c.data[key].(int64); ok { - return int(v), nil + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int); ok { + return vv, nil + } else if vv, ok := v.(int64); ok { + return int(vv), nil } return 0, errors.New("not int value") } @@ -157,8 +163,10 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { // Int64 returns the int64 value for a given key. func (c *ConfigContainer) Int64(key string) (int64, error) { - if v, ok := c.data[key].(int64); ok { - return v, nil + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int64); ok { + return vv, nil } return 0, errors.New("not bool value") } @@ -175,8 +183,14 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { // Float returns the float value for a given key. func (c *ConfigContainer) Float(key string) (float64, error) { - if v, ok := c.data[key].(float64); ok { - return v, nil + if v, err := c.getData(key); err != nil { + return 0.0, err + } else if vv, ok := v.(float64); ok { + return vv, nil + } else if vv, ok := v.(int); ok { + return float64(vv), nil + } else if vv, ok := v.(int64); ok { + return float64(vv), nil } return 0.0, errors.New("not float64 value") } @@ -193,8 +207,10 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { // String returns the string value for a given key. func (c *ConfigContainer) String(key string) string { - if v, ok := c.data[key].(string); ok { - return v + if v, err := c.getData(key); err == nil { + if vv, ok := v.(string); ok { + return vv + } } return "" } @@ -230,8 +246,8 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - v, ok := c.data[section] - if ok { + + if v, ok := c.data[section]; ok { return v.(map[string]string), nil } return nil, errors.New("not exist setction") @@ -259,10 +275,19 @@ func (c *ConfigContainer) Set(key, val string) error { // DIY returns the raw value by a given key. func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + return c.getData(key) +} + +func (c *ConfigContainer) getData(key string) (interface{}, error) { + + if len(key) == 0 { + return nil, errors.New("key is emtpy") + } + if v, ok := c.data[key]; ok { return v, nil } - return nil, errors.New("not exist key") + return nil, fmt.Errorf("not exist key %q", key) } func init() { diff --git a/config/yaml/yaml_test.go b/config/yaml/yaml_test.go index 80cbb8fe..49cc1d1e 100644 --- a/config/yaml/yaml_test.go +++ b/config/yaml/yaml_test.go @@ -15,13 +15,17 @@ package yaml import ( + "fmt" "os" "testing" "github.com/astaxie/beego/config" ) -var yamlcontext = ` +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` "appname": beeapi "httpport": 8080 "mysqlport": 3600 @@ -29,9 +33,27 @@ var yamlcontext = ` "runmode": dev "autorender": false "copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" ` -func TestYaml(t *testing.T) { + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) f, err := os.Create("testyaml.conf") if err != nil { t.Fatal(err) @@ -47,32 +69,42 @@ func TestYaml(t *testing.T) { if err != nil { t.Fatal(err) } + if yamlconf.String("appname") != "beeapi" { t.Fatal("appname not equal to beeapi") } - if port, err := yamlconf.Int("httpport"); err != nil || port != 8080 { - t.Error(port) - t.Fatal(err) - } - if port, err := yamlconf.Int64("mysqlport"); err != nil || port != 3600 { - t.Error(port) - t.Fatal(err) - } - if pi, err := yamlconf.Float("PI"); err != nil || pi != 3.1415976 { - t.Error(pi) - t.Fatal(err) - } - if yamlconf.String("runmode") != "dev" { - t.Fatal("runmode not equal to dev") - } - if v, err := yamlconf.Bool("autorender"); err != nil || v != false { - t.Error(v) - t.Fatal(err) - } - if v, err := yamlconf.Bool("copyrequestbody"); err != nil || v != true { - t.Error(v) - t.Fatal(err) + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + } + if err = yamlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } @@ -80,7 +112,4 @@ func TestYaml(t *testing.T) { t.Fatal("get name error") } - if yamlconf.Strings("emptystrings") != nil { - t.Fatal("get emtpy strings error") - } } diff --git a/config_test.go b/config_test.go index cf4a781d..c9576afd 100644 --- a/config_test.go +++ b/config_test.go @@ -15,7 +15,11 @@ package beego import ( + "encoding/json" + "reflect" "testing" + + "github.com/astaxie/beego/config" ) func TestDefaults(t *testing.T) { @@ -27,3 +31,109 @@ func TestDefaults(t *testing.T) { t.Errorf("FlashName was not set to default.") } } + +func TestAssignConfig_01(t *testing.T) { + _BConfig := &Config{} + _BConfig.AppName = "beego_test" + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) + assignSingleConfig(_BConfig, ac) + if _BConfig.AppName != "beego_json" { + t.Log(_BConfig) + t.FailNow() + } +} + +func TestAssignConfig_02(t *testing.T) { + _BConfig := &Config{} + bs, _ := json.Marshal(newBConfig()) + + jsonMap := map[string]interface{}{} + json.Unmarshal(bs, &jsonMap) + + configMap := map[string]interface{}{} + for k, v := range jsonMap { + if reflect.TypeOf(v).Kind() == reflect.Map { + for k1, v1 := range v.(map[string]interface{}) { + if reflect.TypeOf(v1).Kind() == reflect.Map { + for k2, v2 := range v1.(map[string]interface{}) { + configMap[k2] = v2 + } + } else { + configMap[k1] = v1 + } + } + } else { + configMap[k] = v + } + } + configMap["MaxMemory"] = 1024 + configMap["Graceful"] = true + configMap["XSRFExpire"] = 32 + configMap["SessionProviderConfig"] = "file" + configMap["FileLineNum"] = true + + jcf := &config.JSONConfig{} + bs, _ = json.Marshal(configMap) + ac, _ := jcf.ParseData([]byte(bs)) + + for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + if _BConfig.MaxMemory != 1024 { + t.Log(_BConfig.MaxMemory) + t.FailNow() + } + + if !_BConfig.Listen.Graceful { + t.Log(_BConfig.Listen.Graceful) + t.FailNow() + } + + if _BConfig.WebConfig.XSRFExpire != 32 { + t.Log(_BConfig.WebConfig.XSRFExpire) + t.FailNow() + } + + if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { + t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) + t.FailNow() + } + + if !_BConfig.Log.FileLineNum { + t.Log(_BConfig.Log.FileLineNum) + t.FailNow() + } + +} + +func TestAssignConfig_03(t *testing.T) { + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + assignConfig(ac) + + + t.Logf("%#v",BConfig) + + if BConfig.AppName != "test_app" { + t.FailNow() + } + + if BConfig.RunMode != "online" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download"] != "down" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download2"] != "down2" { + t.FailNow() + } + if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { + t.FailNow() + } +} diff --git a/context/acceptencoder.go b/context/acceptencoder.go index 033d9ca8..cb735445 100644 --- a/context/acceptencoder.go +++ b/context/acceptencoder.go @@ -27,6 +27,33 @@ import ( "sync" ) +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + type resetWriter interface { io.Writer Reset(w io.Writer) @@ -41,20 +68,20 @@ func (n nopResetWriter) Reset(w io.Writer) { } type acceptEncoder struct { - name string - levelEncode func(int) resetWriter - bestSpeedPool *sync.Pool - bestCompressionPool *sync.Pool + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool } func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { - if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { return nopResetWriter{wr} } var rwr resetWriter switch level { case flate.BestSpeed: - rwr = ac.bestSpeedPool.Get().(resetWriter) + rwr = ac.customCompressLevelPool.Get().(resetWriter) case flate.BestCompression: rwr = ac.bestCompressionPool.Get().(resetWriter) default: @@ -65,13 +92,18 @@ func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { } func (ac acceptEncoder) put(wr resetWriter, level int) { - if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { return } wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + switch level { - case flate.BestSpeed: - ac.bestSpeedPool.Put(wr) + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) case flate.BestCompression: ac.bestCompressionPool.Put(wr) } @@ -79,28 +111,22 @@ func (ac acceptEncoder) put(wr resetWriter, level int) { var ( noneCompressEncoder = acceptEncoder{"", nil, nil, nil} - gzipCompressEncoder = acceptEncoder{"gzip", - func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, - &sync.Pool{ - New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr }, - }, - &sync.Pool{ - New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }, - }, + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, } //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed //deflate //The "zlib" format defined in RFC 1950 [31] in combination with //the "deflate" compression mechanism described in RFC 1951 [29]. - deflateCompressEncoder = acceptEncoder{"deflate", - func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, - &sync.Pool{ - New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr }, - }, - &sync.Pool{ - New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }, - }, + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, } ) @@ -120,7 +146,11 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, // WriteBody reads writes content to writer by the specific encoding(gzip/deflate) func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { - return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) } // writeLevel reads from reader,writes to writer by specific encoding and compress level @@ -156,7 +186,10 @@ func ParseEncoding(r *http.Request) string { if r == nil { return "" } - return parseEncoding(r) + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" } type q struct { diff --git a/context/context.go b/context/context.go index 63a1313d..03286097 100644 --- a/context/context.go +++ b/context/context.go @@ -24,13 +24,11 @@ package context import ( "bufio" - "bytes" "crypto/hmac" "crypto/sha1" "encoding/base64" "errors" "fmt" - "io" "net" "net/http" "strconv" @@ -67,18 +65,18 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx.ResponseWriter.reset(rw) ctx.Input.Reset(ctx) ctx.Output.Reset(ctx) + ctx._xsrfToken = "" } // Redirect does redirection to localurl with http header status code. -// It sends http response header directly. func (ctx *Context) Redirect(status int, localurl string) { - ctx.Output.Header("Location", localurl) - ctx.ResponseWriter.WriteHeader(status) + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) } // Abort stops this request. // if beego.ErrorMaps exists, panic body. func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) panic(body) } @@ -195,14 +193,6 @@ func (r *Response) Write(p []byte) (int, error) { return r.ResponseWriter.Write(p) } -// Copy writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (r *Response) Copy(buf *bytes.Buffer) (int64, error) { - r.Started = true - return io.Copy(r.ResponseWriter, buf) -} - // WriteHeader sends an HTTP response header with status code, // and sets `started` to true. func (r *Response) WriteHeader(code int) { diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/context/input.go b/context/input.go index edfdf530..c47996c9 100644 --- a/context/input.go +++ b/context/input.go @@ -89,6 +89,9 @@ func (input *BeegoInput) Site() string { // Scheme returns request scheme as "http" or "https". func (input *BeegoInput) Scheme() string { + if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { + return scheme + } if input.Context.Request.URL.Scheme != "" { return input.Context.Request.URL.Scheme } diff --git a/context/input_test.go b/context/input_test.go index 24f6fd99..8887aec4 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -100,7 +100,7 @@ func TestSubDomain(t *testing.T) { /* TODO Fix this r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "" { t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) } diff --git a/context/output.go b/context/output.go index 17404702..e1ad23e0 100644 --- a/context/output.go +++ b/context/output.go @@ -21,8 +21,11 @@ import ( "errors" "fmt" "html/template" + "io" "mime" "net/http" + "net/url" + "os" "path/filepath" "strconv" "strings" @@ -72,10 +75,11 @@ func (output *BeegoOutput) Body(content []byte) error { if output.Status != 0 { output.Context.ResponseWriter.WriteHeader(output.Status) output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true } - - _, err := output.Context.ResponseWriter.Copy(buf) - return err + io.Copy(output.Context.ResponseWriter, buf) + return nil } // Cookie sets cookie value via given key. @@ -235,13 +239,21 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { // Download forces response for download file. // it prepares the download response header automatically. func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + + var fName string + if len(filename) > 0 && filename[0] != "" { + fName = filename[0] + } else { + fName = filepath.Base(file) + } + output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName)) output.Header("Content-Description", "File Transfer") output.Header("Content-Type", "application/octet-stream") - if len(filename) > 0 && filename[0] != "" { - output.Header("Content-Disposition", "attachment; filename="+filename[0]) - } else { - output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file)) - } output.Header("Content-Transfer-Encoding", "binary") output.Header("Expires", "0") output.Header("Cache-Control", "must-revalidate") @@ -269,55 +281,55 @@ func (output *BeegoOutput) SetStatus(status int) { // IsCachable returns boolean of this request is cached. // HTTP 304 means cached. -func (output *BeegoOutput) IsCachable(status int) bool { +func (output *BeegoOutput) IsCachable() bool { return output.Status >= 200 && output.Status < 300 || output.Status == 304 } // IsEmpty returns boolean of this request is empty. // HTTP 201,204 and 304 means empty. -func (output *BeegoOutput) IsEmpty(status int) bool { +func (output *BeegoOutput) IsEmpty() bool { return output.Status == 201 || output.Status == 204 || output.Status == 304 } // IsOk returns boolean of this request runs well. // HTTP 200 means ok. -func (output *BeegoOutput) IsOk(status int) bool { +func (output *BeegoOutput) IsOk() bool { return output.Status == 200 } // IsSuccessful returns boolean of this request runs successfully. // HTTP 2xx means ok. -func (output *BeegoOutput) IsSuccessful(status int) bool { +func (output *BeegoOutput) IsSuccessful() bool { return output.Status >= 200 && output.Status < 300 } // IsRedirect returns boolean of this request is redirection header. // HTTP 301,302,307 means redirection. -func (output *BeegoOutput) IsRedirect(status int) bool { +func (output *BeegoOutput) IsRedirect() bool { return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 } // IsForbidden returns boolean of this request is forbidden. // HTTP 403 means forbidden. -func (output *BeegoOutput) IsForbidden(status int) bool { +func (output *BeegoOutput) IsForbidden() bool { return output.Status == 403 } // IsNotFound returns boolean of this request is not found. // HTTP 404 means forbidden. -func (output *BeegoOutput) IsNotFound(status int) bool { +func (output *BeegoOutput) IsNotFound() bool { return output.Status == 404 } // IsClientError returns boolean of this request client sends error data. // HTTP 4xx means forbidden. -func (output *BeegoOutput) IsClientError(status int) bool { +func (output *BeegoOutput) IsClientError() bool { return output.Status >= 400 && output.Status < 500 } // IsServerError returns boolean of this server handler errors. // HTTP 5xx means server internal error. -func (output *BeegoOutput) IsServerError(status int) bool { +func (output *BeegoOutput) IsServerError() bool { return output.Status >= 500 && output.Status < 600 } diff --git a/controller.go b/controller.go index 85894275..3a9d1618 100644 --- a/controller.go +++ b/controller.go @@ -208,7 +208,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { continue } buf.Reset() - err = executeTemplate(&buf, sectionTpl, c.Data) + err = ExecuteTemplate(&buf, sectionTpl, c.Data) if err != nil { return nil, err } @@ -217,7 +217,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { } buf.Reset() - executeTemplate(&buf, c.Layout, c.Data) + ExecuteTemplate(&buf, c.Layout, c.Data) } return buf.Bytes(), err } @@ -242,7 +242,7 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) { } BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) } - return buf, executeTemplate(&buf, c.TplName, c.Data) + return buf, ExecuteTemplate(&buf, c.TplName, c.Data) } // Redirect sends the redirection response to url with status code. @@ -261,12 +261,13 @@ func (c *Controller) Abort(code string) { // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. func (c *Controller) CustomAbort(status int, body string) { - c.Ctx.Output.Status = status - // first panic from ErrorMaps, is is user defined error functions. + // first panic from ErrorMaps, it is user defined error functions. if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status panic(body) } // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) c.Ctx.ResponseWriter.Write([]byte(body)) panic(ErrAbort) } diff --git a/error.go b/error.go index 4f48fab2..bad08d86 100644 --- a/error.go +++ b/error.go @@ -210,159 +210,139 @@ var ErrorMaps = make(map[string]*errorInfo, 10) // show 401 unauthorized error. func unauthorized(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(401), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested can't be authorized." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 402 Payment Required func paymentRequired(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(402), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested Payment Required." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 403 forbidden error. func forbidden(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(403), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is forbidden." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) } -// show 404 notfound error. +// show 404 not found error. func notFound(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(404), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested has flown the coop." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 405 Method Not Allowed func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(405), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The method you have requested Not Allowed." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 500 internal server error. func internalServerError(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(500), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is down right now." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

", + ) } // show 501 Not Implemented. func notImplemented(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(504), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is Not Implemented." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

", + ) } // show 502 Bad Gateway. func badGateway(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(502), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is down right now." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

", + ) } // show 503 service unavailable error. func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(503), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is unavailable." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 504 Gateway Timeout. func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

", + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := map[string]interface{}{ - "Title": http.StatusText(504), + "Title": http.StatusText(errCode), "BeegoVersion": VERSION, + "Content": template.HTML(errContent), } - data["Content"] = template.HTML("
The page you have requested is unavailable." + - "
Perhaps you are here because:" + - "

") t.Execute(rw, data) } @@ -408,7 +388,10 @@ func exception(errCode string, ctx *context.Context) { if err == nil { return v } - return 503 + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status } for _, ec := range []string{errCode, "503", "500"} { diff --git a/error_test.go b/error_test.go new file mode 100644 index 00000000..2fb8f962 --- /dev/null +++ b/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if string(w.Body.Bytes()) != parseCodeError { + t.Fail() + } +} diff --git a/filter_test.go b/filter_test.go index d9928d8d..4ca4d2b8 100644 --- a/filter_test.go +++ b/filter_test.go @@ -20,14 +20,8 @@ import ( "testing" "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" ) -func init() { - BeeLogger = logs.NewLogger(10000) - BeeLogger.SetLogger("console", "") -} - var FilterUser = func(ctx *context.Context) { ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) } diff --git a/hooks.go b/hooks.go index 59b10b32..3dca1b8d 100644 --- a/hooks.go +++ b/hooks.go @@ -6,6 +6,8 @@ import ( "net/http" "path/filepath" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/session" ) @@ -45,13 +47,16 @@ func registerSession() error { sessionConfig := AppConfig.String("sessionConfig") if sessionConfig == "" { conf := map[string]interface{}{ - "cookieName": BConfig.WebConfig.Session.SessionName, - "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime, - "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig), - "secure": BConfig.Listen.EnableHTTPS, - "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie, - "domain": BConfig.WebConfig.Session.SessionDomain, - "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime, + "cookieName": BConfig.WebConfig.Session.SessionName, + "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime, + "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig), + "secure": BConfig.Listen.EnableHTTPS, + "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie, + "domain": BConfig.WebConfig.Session.SessionDomain, + "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime, + "enableSidInHttpHeader": BConfig.WebConfig.Session.EnableSidInHttpHeader, + "sessionNameInHttpHeader": BConfig.WebConfig.Session.SessionNameInHttpHeader, + "enableSidInUrlQuery": BConfig.WebConfig.Session.EnableSidInUrlQuery, } confBytes, err := json.Marshal(conf) if err != nil { @@ -70,24 +75,27 @@ func registerSession() error { func registerTemplate() error { if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { if BConfig.RunMode == DEV { - Warn(err) + logs.Warn(err) } return err } return nil } -func registerDocs() error { - if BConfig.WebConfig.EnableDocs { - Get("/docs", serverDocs) - Get("/docs/*", serverDocs) - } - return nil -} - func registerAdmin() error { if BConfig.Listen.EnableAdmin { go beeAdminApp.Run() } return nil } + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/log.go b/log.go index 46ec57dd..e9412f92 100644 --- a/log.go +++ b/log.go @@ -33,82 +33,77 @@ const ( ) // BeeLogger references the used application logger. -var BeeLogger = logs.NewLogger(100) +var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. func SetLevel(l int) { - BeeLogger.SetLevel(l) + logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 func SetLogFuncCall(b bool) { - BeeLogger.EnableFuncCallDepth(b) - BeeLogger.SetLogFuncCallDepth(3) + logs.SetLogFuncCall(b) } // SetLogger sets a new logger. func SetLogger(adaptername string, config string) error { - err := BeeLogger.SetLogger(adaptername, config) - if err != nil { - return err - } - return nil + return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. func Emergency(v ...interface{}) { - BeeLogger.Emergency(generateFmtStr(len(v)), v...) + logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. func Alert(v ...interface{}) { - BeeLogger.Alert(generateFmtStr(len(v)), v...) + logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. func Critical(v ...interface{}) { - BeeLogger.Critical(generateFmtStr(len(v)), v...) + logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. func Error(v ...interface{}) { - BeeLogger.Error(generateFmtStr(len(v)), v...) + logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. func Warning(v ...interface{}) { - BeeLogger.Warning(generateFmtStr(len(v)), v...) + logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() func Warn(v ...interface{}) { - BeeLogger.Warn(generateFmtStr(len(v)), v...) + logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. func Notice(v ...interface{}) { - BeeLogger.Notice(generateFmtStr(len(v)), v...) + logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. func Informational(v ...interface{}) { - BeeLogger.Informational(generateFmtStr(len(v)), v...) + logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() func Info(v ...interface{}) { - BeeLogger.Info(generateFmtStr(len(v)), v...) + logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. func Debug(v ...interface{}) { - BeeLogger.Debug(generateFmtStr(len(v)), v...) + logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() func Trace(v ...interface{}) { - BeeLogger.Trace(generateFmtStr(len(v)), v...) + logs.Trace(generateFmtStr(len(v)), v...) } func generateFmtStr(n int) string { diff --git a/logs/conn.go b/logs/conn.go index 1db1a427..6d5bf6bf 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -113,5 +113,5 @@ func (c *connWriter) needToConnectOnMsg() bool { } func init() { - Register("conn", NewConn) + Register(AdapterConn, NewConn) } diff --git a/logs/console.go b/logs/console.go index dc41dd7d..e6bf6c29 100644 --- a/logs/console.go +++ b/logs/console.go @@ -97,5 +97,5 @@ func (c *consoleWriter) Flush() { } func init() { - Register("console", NewConsole) + Register(AdapterConsole, NewConsole) } diff --git a/logs/es/es.go b/logs/es/es.go index 397ca2ef..22f4f650 100644 --- a/logs/es/es.go +++ b/logs/es/es.go @@ -76,5 +76,5 @@ func (el *esLogger) Flush() { } func init() { - logs.Register("es", NewES) + logs.Register(logs.AdapterEs, NewES) } diff --git a/logs/file.go b/logs/file.go index 9d3f78a0..7798a221 100644 --- a/logs/file.go +++ b/logs/file.go @@ -22,6 +22,7 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" "sync" "time" @@ -30,7 +31,7 @@ import ( // fileLogWriter implements LoggerInterface. // It writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { - sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize // The opened file Filename string `json:"filename"` fileWriter *os.File @@ -47,12 +48,13 @@ type fileLogWriter struct { Daily bool `json:"daily"` MaxDays int64 `json:"maxdays"` dailyOpenDate int + dailyOpenTime time.Time Rotate bool `json:"rotate"` Level int `json:"level"` - Perm os.FileMode `json:"perm"` + Perm string `json:"perm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix } @@ -60,14 +62,11 @@ type fileLogWriter struct { // newFileWriter create a FileLogWriter returning as LoggerInterface. func newFileWriter() Logger { w := &fileLogWriter{ - Filename: "", - MaxLines: 1000000, - MaxSize: 1 << 28, //256 MB - Daily: true, - MaxDays: 7, - Rotate: true, - Level: LevelTrace, - Perm: 0660, + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", } return w } @@ -77,11 +76,11 @@ func newFileWriter() Logger { // { // "filename":"logs/beego.log", // "maxLines":10000, -// "maxsize":1<<30, +// "maxsize":1024, // "daily":true, // "maxDays":15, // "rotate":true, -// "perm":0600 +// "perm":"0600" // } func (w *fileLogWriter) Init(jsonConfig string) error { err := json.Unmarshal([]byte(jsonConfig), w) @@ -128,7 +127,9 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { h, d := formatTimeHeader(when) msg = string(h) + msg + "\n" if w.Rotate { + w.RLock() if w.needRotate(len(msg), d) { + w.RUnlock() w.Lock() if w.needRotate(len(msg), d) { if err := w.doRotate(when); err != nil { @@ -136,6 +137,8 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { } } w.Unlock() + } else { + w.RUnlock() } } @@ -151,7 +154,11 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { func (w *fileLogWriter) createLogFile() (*os.File, error) { // Open the log file - fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm) + perm, err := strconv.ParseInt(w.Perm, 8, 64) + if err != nil { + return nil, err + } + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) return fd, err } @@ -162,8 +169,12 @@ func (w *fileLogWriter) initFd() error { return fmt.Errorf("get stat err: %s\n", err) } w.maxSizeCurSize = int(fInfo.Size()) - w.dailyOpenDate = time.Now().Day() + w.dailyOpenTime = time.Now() + w.dailyOpenDate = w.dailyOpenTime.Day() w.maxLinesCurLines = 0 + if w.Daily { + go w.dailyRotate(w.dailyOpenTime) + } if fInfo.Size() > 0 { count, err := w.lines() if err != nil { @@ -174,6 +185,22 @@ func (w *fileLogWriter) initFd() error { return nil } +func (w *fileLogWriter) dailyRotate(openTime time.Time) { + y, m, d := openTime.Add(24 * time.Hour).Date() + nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) + select { + case <-tm.C: + w.Lock() + if w.needRotate(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } +} + func (w *fileLogWriter) lines() (int, error) { fd, err := os.Open(w.Filename) if err != nil { @@ -204,22 +231,29 @@ func (w *fileLogWriter) lines() (int, error) { // DoRotate means it need to write file in new file. // new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) func (w *fileLogWriter) doRotate(logTime time.Time) error { - _, err := os.Lstat(w.Filename) - if err != nil { - return err - } // file exists // Find the next available number num := 1 fName := "" + + _, err := os.Lstat(w.Filename) + if err != nil { + //even if the file is not exist or other ,we should RESTART the logger + goto RESTART_LOGGER + } + if w.MaxLines > 0 || w.MaxSize > 0 { for ; err == nil && num <= 999; num++ { fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix) _, err = os.Lstat(fName) } } else { - fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, logTime.Format("2006-01-02"), w.suffix) + fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix) _, err = os.Lstat(fName) + for ; err == nil && num <= 999; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix) + _, err = os.Lstat(fName) + } } // return error if the last file checked still existed if err == nil { @@ -231,16 +265,18 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { // Rename the file to its new found name // even if occurs error,we MUST guarantee to restart new logger - renameErr := os.Rename(w.Filename, fName) + err = os.Rename(w.Filename, fName) // re-start logger +RESTART_LOGGER: + startLoggerErr := w.startLogger() go w.deleteOldLog() if startLoggerErr != nil { return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) } - if renameErr != nil { - return fmt.Errorf("Rotate: %s\n", renameErr) + if err != nil { + return fmt.Errorf("Rotate: %s\n", err) } return nil @@ -255,8 +291,12 @@ func (w *fileLogWriter) deleteOldLog() { } }() - if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) { - if strings.HasPrefix(filepath.Base(path), w.fileNameOnly) && + if info == nil { + return + } + + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && strings.HasSuffix(filepath.Base(path), w.suffix) { os.Remove(path) } @@ -278,5 +318,5 @@ func (w *fileLogWriter) Flush() { } func init() { - Register("file", newFileWriter) + Register(AdapterFile, newFileWriter) } diff --git a/logs/file_test.go b/logs/file_test.go index 1fa6cdaa..23370947 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -17,12 +17,34 @@ package logs import ( "bufio" "fmt" + "io/ioutil" "os" "strconv" "testing" "time" ) +func TestFilePerm(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test.log", "perm": "0600"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + file, err := os.Stat("test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0600 { + t.Fatal("unexpected log file permission") + } + os.Remove("test.log") +} + func TestFile1(t *testing.T) { log := NewLogger(10000) log.SetLogger("file", `{"filename":"test.log"}`) @@ -89,7 +111,7 @@ func TestFile2(t *testing.T) { os.Remove("test2.log") } -func TestFileRotate(t *testing.T) { +func TestFileRotate_01(t *testing.T) { log := NewLogger(10000) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.Debug("debug") @@ -110,6 +132,90 @@ func TestFileRotate(t *testing.T) { os.Remove("test3.log") } +func TestFileRotate_02(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + testFileRotate(t, fn1, fn2) +} + +func TestFileRotate_03(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2) + os.Remove(fn) +} + +func TestFileRotate_04(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + testFileDailyRotate(t, fn1, fn2) +} + +func TestFileRotate_05(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) + os.Remove(fn) +} + +func testFileRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileDailyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) + today = today.Add(-1 * time.Second) + fw.dailyRotate(today) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + func exists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { diff --git a/logs/log.go b/logs/log.go index a5982e45..c43782f3 100644 --- a/logs/log.go +++ b/logs/log.go @@ -35,10 +35,12 @@ package logs import ( "fmt" + "log" "os" "path" "runtime" "strconv" + "strings" "sync" "time" ) @@ -55,16 +57,28 @@ const ( LevelDebug ) -// Legacy loglevel constants to ensure backwards compatibility. -// -// Deprecated: will be removed in 1.5.0. +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "stmp" + AdapterConn = "conn" + AdapterEs = "es" +) + +// Legacy log level constants to ensure backwards compatibility. const ( LevelInfo = LevelInformational LevelTrace = LevelDebug LevelWarn = LevelWarning ) -type loggerType func() Logger +type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { @@ -74,12 +88,13 @@ type Logger interface { Flush() } -var adapters = make(map[string]loggerType) +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "} // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, // it panics. -func Register(name string, log loggerType) { +func Register(name string, log newLoggerFunc) { if log == nil { panic("logs: Register provide is nil") } @@ -94,15 +109,19 @@ func Register(name string, log loggerType) { type BeeLogger struct { lock sync.Mutex level int + init bool enableFuncCallDepth bool loggerFuncCallDepth int asynchronous bool + msgChanLen int64 msgChan chan *logMsg signalChan chan string wg sync.WaitGroup outputs []*nameLogger } +const defaultAsyncMsgLen = 1e3 + type nameLogger struct { Logger name string @@ -119,18 +138,31 @@ var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. // channelLen means the number of messages in chan(used where asynchronous is true). // if the buffering chan is full, logger adapters write to file or other way. -func NewLogger(channelLen int64) *BeeLogger { +func NewLogger(channelLens ...int64) *BeeLogger { bl := new(BeeLogger) bl.level = LevelDebug bl.loggerFuncCallDepth = 2 - bl.msgChan = make(chan *logMsg, channelLen) + bl.msgChanLen = append(channelLens, 0)[0] + if bl.msgChanLen <= 0 { + bl.msgChanLen = defaultAsyncMsgLen + } bl.signalChan = make(chan string, 1) + bl.setLogger(AdapterConsole) return bl } // Async set the log to asynchronous and start the goroutine -func (bl *BeeLogger) Async() *BeeLogger { +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + bl.lock.Lock() + defer bl.lock.Unlock() + if bl.asynchronous { + return bl + } bl.asynchronous = true + if len(msgLen) > 0 && msgLen[0] > 0 { + bl.msgChanLen = msgLen[0] + } + bl.msgChan = make(chan *logMsg, bl.msgChanLen) logMsgPool = &sync.Pool{ New: func() interface{} { return &logMsg{} @@ -143,10 +175,8 @@ func (bl *BeeLogger) Async() *BeeLogger { // SetLogger provides a given logger adapter into BeeLogger with config string. // config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) SetLogger(adapterName string, config string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - +func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { + config := append(configs, "{}")[0] for _, l := range bl.outputs { if l.name == adapterName { return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) @@ -168,6 +198,18 @@ func (bl *BeeLogger) SetLogger(adapterName string, config string) error { return nil } +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLogger(adapterName, configs...) +} + // DelLogger remove a logger adapter in BeeLogger. func (bl *BeeLogger) DelLogger(adapterName string) error { bl.lock.Lock() @@ -196,7 +238,37 @@ func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { } } -func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // writeMsg will always add a '\n' character + if p[len(p)-1] == '\n' { + p = p[0 : len(p)-1] + } + // set levelLoggerImpl to ensure all log message will be write out + err = bl.writeMsg(levelLoggerImpl, string(p)) + if err == nil { + return len(p), err + } + return 0, err +} + +func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { + if !bl.init { + bl.lock.Lock() + bl.setLogger(AdapterConsole) + bl.lock.Unlock() + } + if logLevel == levelLoggerImpl { + // set to emergency to ensure all log will be print out correctly + logLevel = LevelEmergency + } else { + msg = levelPrefix[logLevel] + msg + } + if len(v) > 0 { + msg = fmt.Sprintf(msg, v...) + } when := time.Now() if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) @@ -205,7 +277,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { line = 0 } _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg + msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg } if bl.asynchronous { lm := logMsgPool.Get().(*logMsg) @@ -273,8 +345,7 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { return } - msg := fmt.Sprintf("[M] "+format, v...) - bl.writeMsg(LevelEmergency, msg) + bl.writeMsg(LevelEmergency, format, v...) } // Alert Log ALERT level message. @@ -282,8 +353,7 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { if LevelAlert > bl.level { return } - msg := fmt.Sprintf("[A] "+format, v...) - bl.writeMsg(LevelAlert, msg) + bl.writeMsg(LevelAlert, format, v...) } // Critical Log CRITICAL level message. @@ -291,8 +361,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { if LevelCritical > bl.level { return } - msg := fmt.Sprintf("[C] "+format, v...) - bl.writeMsg(LevelCritical, msg) + bl.writeMsg(LevelCritical, format, v...) } // Error Log ERROR level message. @@ -300,17 +369,12 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { if LevelError > bl.level { return } - msg := fmt.Sprintf("[E] "+format, v...) - bl.writeMsg(LevelError, msg) + bl.writeMsg(LevelError, format, v...) } // Warning Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { - if LevelWarning > bl.level { - return - } - msg := fmt.Sprintf("[W] "+format, v...) - bl.writeMsg(LevelWarning, msg) + bl.Warn(format, v...) } // Notice Log NOTICE level message. @@ -318,17 +382,12 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { if LevelNotice > bl.level { return } - msg := fmt.Sprintf("[N] "+format, v...) - bl.writeMsg(LevelNotice, msg) + bl.writeMsg(LevelNotice, format, v...) } // Informational Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { - if LevelInformational > bl.level { - return - } - msg := fmt.Sprintf("[I] "+format, v...) - bl.writeMsg(LevelInformational, msg) + bl.Info(format, v...) } // Debug Log DEBUG level message. @@ -336,38 +395,31 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { if LevelDebug > bl.level { return } - msg := fmt.Sprintf("[D] "+format, v...) - bl.writeMsg(LevelDebug, msg) + bl.writeMsg(LevelDebug, format, v...) } // Warn Log WARN level message. // compatibility alias for Warning() func (bl *BeeLogger) Warn(format string, v ...interface{}) { - if LevelWarning > bl.level { + if LevelWarn > bl.level { return } - msg := fmt.Sprintf("[W] "+format, v...) - bl.writeMsg(LevelWarning, msg) + bl.writeMsg(LevelWarn, format, v...) } // Info Log INFO level message. // compatibility alias for Informational() func (bl *BeeLogger) Info(format string, v ...interface{}) { - if LevelInformational > bl.level { + if LevelInfo > bl.level { return } - msg := fmt.Sprintf("[I] "+format, v...) - bl.writeMsg(LevelInformational, msg) + bl.writeMsg(LevelInfo, format, v...) } // Trace Log TRACE level message. // compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - msg := fmt.Sprintf("[D] "+format, v...) - bl.writeMsg(LevelDebug, msg) + bl.Debug(format, v...) } // Flush flush all chan data. @@ -386,6 +438,7 @@ func (bl *BeeLogger) Close() { if bl.asynchronous { bl.signalChan <- "close" bl.wg.Wait() + close(bl.msgChan) } else { bl.flush() for _, l := range bl.outputs { @@ -393,7 +446,6 @@ func (bl *BeeLogger) Close() { } bl.outputs = nil } - close(bl.msgChan) close(bl.signalChan) } @@ -407,16 +459,175 @@ func (bl *BeeLogger) Reset() { } func (bl *BeeLogger) flush() { - for { - if len(bl.msgChan) > 0 { - bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - continue + if bl.asynchronous { + for { + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + continue + } + break } - break } for _, l := range bl.outputs { l.Flush() } } + +// beeLogger references the used application logger. +var beeLogger *BeeLogger = NewLogger() + +// GetLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return beeLogger +} + +var beeLoggerMap = struct { + sync.RWMutex + logs map[string]*log.Logger +}{ + logs: map[string]*log.Logger{}, +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + prefix := append(prefixes, "")[0] + if prefix != "" { + prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) + } + beeLoggerMap.RLock() + l, ok := beeLoggerMap.logs[prefix] + if ok { + beeLoggerMap.RUnlock() + return l + } + beeLoggerMap.RUnlock() + beeLoggerMap.Lock() + defer beeLoggerMap.Unlock() + l, ok = beeLoggerMap.logs[prefix] + if !ok { + l = log.New(beeLogger, prefix, 0) + beeLoggerMap.logs[prefix] = l + } + return l +} + +// Reset will remove all the adapter +func Reset() { + beeLogger.Reset() +} + +func Async(msgLen ...int64) *BeeLogger { + return beeLogger.Async(msgLen...) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + beeLogger.SetLevel(l) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + beeLogger.enableFuncCallDepth = b +} + +// SetLogFuncCall set the CallDepth, default is 3 +func SetLogFuncCall(b bool) { + beeLogger.EnableFuncCallDepth(b) + beeLogger.SetLogFuncCallDepth(3) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + beeLogger.loggerFuncCallDepth = d +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + err := beeLogger.SetLogger(adapter, config...) + if err != nil { + return err + } + return nil +} + +// Emergency logs a message at emergency level. +func Emergency(f interface{}, v ...interface{}) { + beeLogger.Emergency(formatLog(f, v...)) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + beeLogger.Alert(formatLog(f, v...)) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + beeLogger.Critical(formatLog(f, v...)) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + beeLogger.Error(formatLog(f, v...)) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + beeLogger.Notice(formatLog(f, v...)) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + beeLogger.Debug(formatLog(f, v...)) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + beeLogger.Trace(formatLog(f, v...)) +} + +func formatLog(f interface{}, v ...interface{}) string { + var msg string + switch f.(type) { + case string: + msg = f.(string) + if len(v) == 0 { + return msg + } + if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { + //format string + } else { + //do not contain format char + msg += strings.Repeat(" %v", len(v)) + } + default: + msg = fmt.Sprint(f) + if len(v) == 0 { + return msg + } + msg += strings.Repeat(" %v", len(v)) + } + return fmt.Sprintf(msg, v...) +} diff --git a/logs/logger.go b/logs/logger.go index b25bfaef..2f47e569 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -36,43 +36,46 @@ func (lg *logWriter) println(when time.Time, msg string) { lg.Unlock() } +const y1 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` +const y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` +const mo1 = `000000000111` +const mo2 = `123456789012` +const d1 = `0000000001111111111222222222233` +const d2 = `1234567890123456789012345678901` +const h1 = `000000000011111111112222` +const h2 = `012345678901234567890123` +const mi1 = `000000000011111111112222222222333333333344444444445555555555` +const mi2 = `012345678901234567890123456789012345678901234567890123456789` +const s1 = `000000000011111111112222222222333333333344444444445555555555` +const s2 = `012345678901234567890123456789012345678901234567890123456789` + func formatTimeHeader(when time.Time) ([]byte, int) { y, mo, d := when.Date() h, mi, s := when.Clock() - //len(2006/01/02 15:03:04)==19 + //len("2006/01/02 15:04:05 ")==20 var buf [20]byte - t := 3 - for y >= 10 { - p := y / 10 - buf[t] = byte('0' + y - p*10) - y = p - t-- - } - buf[0] = byte('0' + y) + + //change to '3' after 984 years, LOL + buf[0] = '2' + //change to '1' after 84 years, LOL + buf[1] = '0' + buf[2] = y1[y-2000] + buf[3] = y2[y-2000] buf[4] = '/' - if mo > 9 { - buf[5] = '1' - buf[6] = byte('0' + mo - 9) - } else { - buf[5] = '0' - buf[6] = byte('0' + mo) - } + buf[5] = mo1[mo-1] + buf[6] = mo2[mo-1] buf[7] = '/' - t = d / 10 - buf[8] = byte('0' + t) - buf[9] = byte('0' + d - t*10) + buf[8] = d1[d-1] + buf[9] = d2[d-1] buf[10] = ' ' - t = h / 10 - buf[11] = byte('0' + t) - buf[12] = byte('0' + h - t*10) + buf[11] = h1[h] + buf[12] = h2[h] buf[13] = ':' - t = mi / 10 - buf[14] = byte('0' + t) - buf[15] = byte('0' + mi - t*10) + buf[14] = mi1[mi] + buf[15] = mi2[mi] buf[16] = ':' - t = s / 10 - buf[17] = byte('0' + t) - buf[18] = byte('0' + s - t*10) + buf[17] = s1[s] + buf[18] = s2[s] buf[19] = ' ' return buf[0:], d diff --git a/logs/logger_test.go b/logs/logger_test.go new file mode 100644 index 00000000..4627853a --- /dev/null +++ b/logs/logger_test.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +func TestFormatHeader_0(t *testing.T) { + tm := time.Now() + if tm.Year() >= 2100 { + t.FailNow() + } + dur := time.Second + for { + if tm.Year() >= 2100 { + break + } + h, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + dur *= 2 + } +} + +func TestFormatHeader_1(t *testing.T) { + tm := time.Now() + year := tm.Year() + dur := time.Second + for { + if tm.Year() >= year+1 { + break + } + h, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + } +} diff --git a/logs/multifile.go b/logs/multifile.go index b82ba274..63204e17 100644 --- a/logs/multifile.go +++ b/logs/multifile.go @@ -112,5 +112,5 @@ func newFilesWriter() Logger { } func init() { - Register("multifile", newFilesWriter) + Register(AdapterMultiFile, newFilesWriter) } diff --git a/logs/smtp.go b/logs/smtp.go index 47f5a0c6..834130ef 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -156,5 +156,5 @@ func (s *SMTPWriter) Destroy() { } func init() { - Register("smtp", newSMTPWriter) + Register(AdapterMail, newSMTPWriter) } diff --git a/migration/migration.go b/migration/migration.go index 1591bc50..c9ca1bc6 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -33,7 +33,7 @@ import ( "strings" "time" - "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/orm" ) @@ -90,7 +90,7 @@ func (m *Migration) Reset() { func (m *Migration) Exec(name, status string) error { o := orm.NewOrm() for _, s := range m.sqls { - beego.Info("exec sql:", s) + logs.Info("exec sql:", s) r := o.Raw(s) _, err := r.Exec() if err != nil { @@ -144,20 +144,20 @@ func Upgrade(lasttime int64) error { i := 0 for _, v := range sm { if v.created > lasttime { - beego.Info("start upgrade", v.name) + logs.Info("start upgrade", v.name) v.m.Reset() v.m.Up() err := v.m.Exec(v.name, "up") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } - beego.Info("end upgrade:", v.name) + logs.Info("end upgrade:", v.name) i++ } } - beego.Info("total success upgrade:", i, " migration") + logs.Info("total success upgrade:", i, " migration") time.Sleep(2 * time.Second) return nil } @@ -165,20 +165,20 @@ func Upgrade(lasttime int64) error { // Rollback rollback the migration by the name func Rollback(name string) error { if v, ok := migrationMap[name]; ok { - beego.Info("start rollback") + logs.Info("start rollback") v.Reset() v.Down() err := v.Exec(name, "down") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } - beego.Info("end rollback") + logs.Info("end rollback") time.Sleep(2 * time.Second) return nil } - beego.Error("not exist the migrationMap name:" + name) + logs.Error("not exist the migrationMap name:" + name) time.Sleep(2 * time.Second) return errors.New("not exist the migrationMap name:" + name) } @@ -191,23 +191,23 @@ func Reset() error { for j := len(sm) - 1; j >= 0; j-- { v := sm[j] if isRollBack(v.name) { - beego.Info("skip the", v.name) + logs.Info("skip the", v.name) time.Sleep(1 * time.Second) continue } - beego.Info("start reset:", v.name) + logs.Info("start reset:", v.name) v.m.Reset() v.m.Down() err := v.m.Exec(v.name, "down") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } i++ - beego.Info("end reset:", v.name) + logs.Info("end reset:", v.name) } - beego.Info("total success reset:", i, " migration") + logs.Info("total success reset:", i, " migration") time.Sleep(2 * time.Second) return nil } @@ -216,7 +216,7 @@ func Reset() error { func Refresh() error { err := Reset() if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } @@ -265,7 +265,7 @@ func isRollBack(name string) bool { var maps []orm.Params num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) if err != nil { - beego.Info("get name has error", err) + logs.Info("get name has error", err) return false } if num <= 0 { diff --git a/namespace.go b/namespace.go index 4007d44c..cfde0111 100644 --- a/namespace.go +++ b/namespace.go @@ -44,7 +44,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { return ns } -// Cond set condtion function +// Cond set condition function // if cond return true can run this namespace, else can't // usage: // ns.Cond(func (ctx *context.Context) bool{ @@ -60,7 +60,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace { exception("405", ctx) } } - if v, ok := n.handlers.filters[BeforeRouter]; ok { + if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = "*" diff --git a/namespace_test.go b/namespace_test.go index a92ae3ef..fc02b5fb 100644 --- a/namespace_test.go +++ b/namespace_test.go @@ -61,8 +61,8 @@ func TestNamespaceNest(t *testing.T) { ns.Namespace( NewNamespace("/admin"). Get("/order", func(ctx *context.Context) { - ctx.Output.Body([]byte("order")) - }), + ctx.Output.Body([]byte("order")) + }), ) AddNamespace(ns) BeeApp.Handlers.ServeHTTP(w, r) @@ -79,8 +79,8 @@ func TestNamespaceNestParam(t *testing.T) { ns.Namespace( NewNamespace("/admin"). Get("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), ) AddNamespace(ns) BeeApp.Handlers.ServeHTTP(w, r) @@ -124,8 +124,8 @@ func TestNamespaceFilter(t *testing.T) { ctx.Output.Body([]byte("this is Filter")) }). Get("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) AddNamespace(ns) BeeApp.Handlers.ServeHTTP(w, r) if w.Body.String() != "this is Filter" { diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index da0ee8ab..8119b70b 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -52,9 +52,15 @@ checkColumn: case TypeBooleanField: col = T["bool"] case TypeCharField: - col = fmt.Sprintf(T["string"], fieldSize) + if al.Driver == DRPostgres && fi.toText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } case TypeTextField: col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] case TypeDateField: col = T["time.Time-date"] case TypeDateTimeField: @@ -88,6 +94,18 @@ checkColumn: } else { col = fmt.Sprintf(s, fi.digits, fi.decimals) } + case TypeJSONField: + if al.Driver != DRPostgres { + fieldType = TypeCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if al.Driver != DRPostgres { + fieldType = TypeCharField + goto checkColumn + } + col = T["jsonb"] case RelForeignKey, RelOneToOne: fieldType = fi.relModelInfo.fields.pk.fieldType fieldSize = fi.relModelInfo.fields.pk.size @@ -264,7 +282,7 @@ func getColumnDefault(fi *fieldInfo) string { // These defaults will be useful if there no config value orm:"default" and NOT NULL is on switch fi.fieldType { - case TypeDateField, TypeDateTimeField, TypeTextField: + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: return v case TypeBitField, TypeSmallIntegerField, TypeIntegerField, @@ -276,6 +294,8 @@ func getColumnDefault(fi *fieldInfo) string { case TypeBooleanField: t = " DEFAULT %s " d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" } if fi.colDefault { diff --git a/orm/db.go b/orm/db.go index 314c3535..9964e263 100644 --- a/orm/db.go +++ b/orm/db.go @@ -24,6 +24,7 @@ import ( ) const ( + formatTime = "15:04:05" formatDate = "2006-01-02" formatDateTime = "2006-01-02 15:04:05" ) @@ -71,12 +72,12 @@ type dbBase struct { var _ dbBaser = new(dbBase) // get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { - var columns []string - - if names != nil { - columns = *names +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { + if names == nil { + ns := make([]string, 0, len(cols)) + names = &ns } + values = make([]interface{}, 0, len(cols)) for _, column := range cols { var fi *fieldInfo @@ -90,18 +91,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) if err != nil { - return nil, err + return nil, nil, err } - if names != nil { - columns = append(columns, column) + // ignore empty value auto field + if insert && fi.auto { + if fi.fieldType&IsPositiveIntegerField > 0 { + if vu, ok := value.(uint64); !ok || vu == 0 { + continue + } + } else { + if vu, ok := value.(int64); !ok || vu == 0 { + continue + } + } + autoFields = append(autoFields, fi.column) } - values = append(values, value) - } - - if names != nil { - *names = columns + *names, values = append(*names, column), append(values, value) } return @@ -134,7 +141,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { value = field.Bool() } - case TypeCharField, TypeTextField: + case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: if ns, ok := field.Interface().(sql.NullString); ok { value = nil if ns.Valid { @@ -169,7 +176,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val value = field.Float() } } - case TypeDateField, TypeDateTimeField: + case TypeTimeField, TypeDateField, TypeDateTimeField: value = field.Interface() if t, ok := value.(time.Time); ok { d.ins.TimeToDB(&t, tz) @@ -181,7 +188,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } default: switch { - case fi.fieldType&IsPostiveIntegerField > 0: + case fi.fieldType&IsPositiveIntegerField > 0: if field.Kind() == reflect.Ptr { if field.IsNil() { value = nil @@ -223,7 +230,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } } switch fi.fieldType { - case TypeDateField, TypeDateTimeField: + case TypeTimeField, TypeDateField, TypeDateTimeField: if fi.autoNow || fi.autoNowAdd && insert { if insert { if t, ok := value.(time.Time); ok && !t.IsZero() { @@ -236,10 +243,21 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val if fi.isFielder { f := field.Addr().Interface().(Fielder) f.SetRaw(tnow.In(DefaultTimeLoc)) + } else if field.Kind() == reflect.Ptr { + v := tnow.In(DefaultTimeLoc) + field.Set(reflect.ValueOf(&v)) } else { field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) } } + case TypeJSONField, TypeJsonbField: + if s, ok := value.(string); (ok && len(s) == 0) || value == nil { + if fi.colDefault && fi.initial.Exist() { + value = fi.initial.String() + } else { + value = nil + } + } } } return value, nil @@ -273,7 +291,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, // insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err } @@ -300,7 +318,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo if len(cols) > 0 { var err error whereCols = make([]string, 0, len(cols)) - args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) if err != nil { return err } @@ -349,13 +367,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo // execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)-1) - values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + names := make([]string, 0, len(mi.fields.dbcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - return d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err } // multi-insert sql with given slice struct reflect.Value. @@ -369,7 +395,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // typ := reflect.Indirect(mi.addrField).Type() - length := sind.Len() + length, autoFields := sind.Len(), make([]string, 0, 1) for i := 1; i <= length; i++ { @@ -381,16 +407,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // } if i == 1 { - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return cnt, err } values = make([]interface{}, bulk*len(vus)) nums += copy(values, vus) - } else { - - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) if err != nil { return cnt, err } @@ -412,7 +440,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } } - return cnt, nil + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err } // execute insert sql with given struct and given values. @@ -472,7 +505,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. setNames = make([]string, 0, len(cols)) } - setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) if err != nil { return 0, err } @@ -516,7 +549,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. } if num > 0 { if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) } else { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) @@ -1071,13 +1104,13 @@ setValue: } value = b } - case fieldType == TypeCharField || fieldType == TypeTextField: + case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: if str == nil { value = ToStr(val) } else { value = str.String() } - case fieldType == TypeDateField || fieldType == TypeDateTimeField: + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: if str == nil { switch t := val.(type) { case time.Time: @@ -1097,15 +1130,20 @@ setValue: if len(s) >= 19 { s = s[:19] t, err = time.ParseInLocation(formatDateTime, s, tz) - } else { + } else if len(s) >= 10 { if len(s) > 10 { s = s[:10] } t, err = time.ParseInLocation(formatDate, s, tz) + } else if len(s) >= 8 { + if len(s) > 8 { + s = s[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) } t = t.In(DefaultTimeLoc) - if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { + if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { tErr = err goto end } @@ -1140,7 +1178,7 @@ setValue: tErr = err goto end } - if fieldType&IsPostiveIntegerField > 0 { + if fieldType&IsPositiveIntegerField > 0 { v, _ := str.Uint64() value = v } else { @@ -1212,7 +1250,7 @@ setValue: field.SetBool(value.(bool)) } } - case fieldType == TypeCharField || fieldType == TypeTextField: + case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: if isNative { if ns, ok := field.Interface().(sql.NullString); ok { if value == nil { @@ -1234,12 +1272,18 @@ setValue: field.SetString(value.(string)) } } - case fieldType == TypeDateField || fieldType == TypeDateTimeField: + case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: if isNative { if value == nil { value = time.Time{} + } else if field.Kind() == reflect.Ptr { + if value != nil { + v := value.(time.Time) + field.Set(reflect.ValueOf(&v)) + } + } else { + field.Set(reflect.ValueOf(value)) } - field.Set(reflect.ValueOf(value)) } case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: if value != nil { @@ -1292,7 +1336,7 @@ setValue: field.Set(reflect.ValueOf(&v)) } case fieldType&IsIntegerField > 0: - if fieldType&IsPostiveIntegerField > 0 { + if fieldType&IsPositiveIntegerField > 0 { if isNative { if value == nil { value = uint64(0) @@ -1440,7 +1484,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond sels := strings.Join(cols, ", ") - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) @@ -1562,6 +1610,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + // convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 7dbef95a..e972c4a2 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -56,6 +56,8 @@ var postgresTypes = map[string]string{ "uint64": `bigint CHECK("%COL%" >= 0)`, "float64": "double precision", "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", } // postgresql dbBaser. @@ -123,14 +125,35 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { } // make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { - if mi.fields.pk.auto { - if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) - } - has = true +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { + fi := mi.fields.pk + if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { + return false } - return + + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + } + return true +} + +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil } // show table sql for postgresql. diff --git a/orm/db_utils.go b/orm/db_utils.go index c97caf36..cf465d02 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -33,13 +33,13 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac fi := mi.fields.pk v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPostiveIntegerField > 0 { + if fi.fieldType&IsPositiveIntegerField > 0 { vu := v.Uint() exist = vu > 0 value = vu } else if fi.fieldType&IsIntegerField > 0 { vu := v.Int() - exist = vu > 0 + exist = true value = vu } else { vu := v.String() @@ -74,24 +74,32 @@ outFor: case reflect.String: v := val.String() if fi != nil { - if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { var t time.Time var err error if len(v) >= 19 { s := v[:19] t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) - } else { + } else if len(v) >= 10 { s := v if len(v) > 10 { s = v[:10] } t, err = time.ParseInLocation(formatDate, s, tz) + } else { + s := v + if len(s) > 8 { + s = v[:8] + } + t, err = time.ParseInLocation(formatTime, s, tz) } if err == nil { if fi.fieldType == TypeDateField { v = t.In(tz).Format(formatDate) - } else { + } else if fi.fieldType == TypeDateTimeField { v = t.In(tz).Format(formatDateTime) + } else { + v = t.In(tz).Format(formatTime) } } } @@ -137,8 +145,10 @@ outFor: if v, ok := arg.(time.Time); ok { if fi != nil && fi.fieldType == TypeDateField { arg = v.In(tz).Format(formatDate) - } else { + } else if fi.fieldType == TypeDateTimeField { arg = v.In(tz).Format(formatDateTime) + } else { + arg = v.In(tz).Format(formatTime) } } else { typ := val.Type() diff --git a/orm/models_boot.go b/orm/models_boot.go index 3690557b..c9905330 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -66,7 +66,7 @@ func registerModel(prefix string, model interface{}) { } if info.fields.pk == nil { - fmt.Printf(" `%s` need a primary key field\n", name) + fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name) os.Exit(2) } diff --git a/orm/models_fields.go b/orm/models_fields.go index a8cf8e4f..57820600 100644 --- a/orm/models_fields.go +++ b/orm/models_fields.go @@ -25,6 +25,7 @@ const ( TypeBooleanField = 1 << iota TypeCharField TypeTextField + TypeTimeField TypeDateField TypeDateTimeField TypeBitField @@ -37,6 +38,8 @@ const ( TypePositiveBigIntegerField TypeFloatField TypeDecimalField + TypeJSONField + TypeJsonbField RelForeignKey RelOneToOne RelManyToMany @@ -46,10 +49,10 @@ const ( // Define some logic enum const ( - IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 - IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 - IsRelField = ^-RelReverseMany >> 14 << 15 - IsFieldType = ^-RelReverseMany<<1 + 1 + IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10 + IsRelField = ^-RelReverseMany >> 17 << 18 + IsFieldType = ^-RelReverseMany<<1 + 1 ) // BooleanField A true/false field. @@ -145,6 +148,65 @@ func (e *CharField) RawValue() interface{} { // verify CharField implement Fielder var _ Fielder = new(CharField) +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField time.Time + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return time.Time(e) +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + *e = TimeField(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return e.Value().String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return TypeDateField +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + switch d := value.(type) { + case time.Time: + e.Set(d) + case string: + v, err := timeParse(d, formatTime) + if err != nil { + e.Set(v) + } + return err + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return e.Value() +} + +var _ Fielder = new(TimeField) + // DateField A date, represented in go by a time.Time instance. // only date values like 2006-01-02 // Has a few extra, optional attr tag: @@ -627,3 +689,87 @@ func (e *TextField) RawValue() interface{} { // verify TextField implement Fielder var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField string + +// Value return JSONField value +func (j JSONField) Value() string { + return string(j) +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + *j = JSONField(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return TypeJSONField +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return j.Value() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField string + +// Value return JsonbField value +func (j JsonbField) Value() string { + return string(j) +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + *j = JsonbField(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return j.Value() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return TypeJsonbField +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + j.Set(d) + default: + return fmt.Errorf(" unknown value `%s`", value) + } + return nil +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return j.Value() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 996a2f40..be6c9aa4 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -119,6 +119,7 @@ type fieldInfo struct { colDefault bool initial StrTo size int + toText bool autoNow bool autoNowAdd bool rel bool @@ -239,8 +240,15 @@ checkType: if err != nil { goto end } - if fieldType == TypeCharField && tags["type"] == "text" { - fieldType = TypeTextField + if fieldType == TypeCharField { + switch tags["type"] { + case "text": + fieldType = TypeTextField + case "json": + fieldType = TypeJSONField + case "jsonb": + fieldType = TypeJsonbField + } } if fieldType == TypeFloatField && (digits != "" || decimals != "") { fieldType = TypeDecimalField @@ -248,6 +256,9 @@ checkType: if fieldType == TypeDateTimeField && tags["type"] == "date" { fieldType = TypeDateField } + if fieldType == TypeTimeField && tags["type"] == "time" { + fieldType = TypeTimeField + } } switch fieldType { @@ -339,7 +350,7 @@ checkType: switch fieldType { case TypeBooleanField: - case TypeCharField: + case TypeCharField, TypeJSONField, TypeJsonbField: if size != "" { v, e := StrTo(size).Int32() if e != nil { @@ -349,11 +360,12 @@ checkType: } } else { fi.size = 255 + fi.toText = true } case TypeTextField: fi.index = false fi.unique = false - case TypeDateField, TypeDateTimeField: + case TypeTimeField, TypeDateField, TypeDateTimeField: if attrs["auto_now"] { fi.autoNow = true } else if attrs["auto_now_add"] { @@ -406,7 +418,7 @@ checkType: fi.index = false } - if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { + if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { // can not set default initial.Clear() } diff --git a/orm/models_test.go b/orm/models_test.go index ffb16ea0..462370b2 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -78,40 +78,43 @@ func (e *SliceStringField) RawValue() interface{} { var _ Fielder = new(SliceStringField) // A json field. -type JSONField struct { +type JSONFieldTest struct { Name string Data string } -func (e *JSONField) String() string { +func (e *JSONFieldTest) String() string { data, _ := json.Marshal(e) return string(data) } -func (e *JSONField) FieldType() int { +func (e *JSONFieldTest) FieldType() int { return TypeTextField } -func (e *JSONField) SetRaw(value interface{}) error { +func (e *JSONFieldTest) SetRaw(value interface{}) error { switch d := value.(type) { case string: return json.Unmarshal([]byte(d), e) default: - return fmt.Errorf(" unknown value `%v`", value) + return fmt.Errorf(" unknown value `%v`", value) } } -func (e *JSONField) RawValue() interface{} { +func (e *JSONFieldTest) RawValue() interface{} { return e.String() } -var _ Fielder = new(JSONField) +var _ Fielder = new(JSONFieldTest) type Data struct { ID int `orm:"column(id)"` Boolean bool Char string `orm:"size(50)"` Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` Date time.Time `orm:"type(date)"` DateTime time.Time `orm:"column(datetime)"` Byte byte @@ -136,6 +139,9 @@ type DataNull struct { Boolean bool `orm:"null"` Char string `orm:"null;size(50)"` Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` Date time.Time `orm:"null;type(date)"` DateTime time.Time `orm:"null;column(datetime)"` Byte byte `orm:"null"` @@ -175,6 +181,9 @@ type DataNull struct { Float32Ptr *float32 `orm:"null"` Float64Ptr *float64 `orm:"null"` DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` } type String string @@ -237,7 +246,7 @@ type User struct { ShouldSkip string `orm:"-"` Nums int Langs SliceStringField `orm:"size(100)"` - Extra JSONField `orm:"type(text)"` + Extra JSONFieldTest `orm:"type(text)"` unexport bool `orm:"-"` unexportBool bool } @@ -375,6 +384,28 @@ func NewInLine() *InLine { return new(InLine) } +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + var DBARGS = struct { Driver string Source string diff --git a/orm/models_utils.go b/orm/models_utils.go index ec11d516..4c4b0f24 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -137,6 +137,8 @@ func getFieldType(val reflect.Value) (ft int, err error) { ft = TypeBooleanField case reflect.TypeOf(new(string)): ft = TypeCharField + case reflect.TypeOf(new(time.Time)): + ft = TypeDateTimeField default: elm := reflect.Indirect(val) switch elm.Kind() { @@ -192,10 +194,10 @@ func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string tag := make(map[string]string) for _, v := range strings.Split(data, defaultStructTagDelim) { v = strings.TrimSpace(v) - if supportTag[v] == 1 { - attr[v] = true + if t := strings.ToLower(v); supportTag[t] == 1 { + attr[t] = true } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { - name := v[:i] + name := t[:i] if supportTag[name] == 2 { v = v[i+1 : len(v)-1] tag[name] = v diff --git a/orm/orm.go b/orm/orm.go index 0ffb6b86..5e43ae59 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -140,7 +140,14 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i return (err == nil), id, err } - return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err + id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else { + id = vid.Int() + } + + return false, id, err } // insert model data to database @@ -159,7 +166,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { // set auto pk field func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) } else { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) @@ -184,7 +191,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { if bulk <= 1 { for i := 0; i < sind.Len(); i++ { - ind := sind.Index(i) + ind := reflect.Indirect(sind.Index(i)) mi, _ := o.getMiInd(ind.Interface(), false) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { diff --git a/orm/orm_log.go b/orm/orm_log.go index 712eb219..54723273 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -31,7 +31,7 @@ type Log struct { // NewLog set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) - d.Logger = log.New(out, "[ORM]", 1e9) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) return d } diff --git a/orm/orm_object.go b/orm/orm_object.go index 8a5d85e2..de3181ce 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -50,7 +50,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { } if id > 0 { if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) } else { ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) diff --git a/orm/orm_test.go b/orm/orm_test.go index d9fd7d51..b5973448 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" "io/ioutil" + "math" "os" "path/filepath" "reflect" @@ -33,6 +34,7 @@ var _ = os.PathSeparator var ( testDate = formatDate + " -0700" testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" ) type argAny []interface{} @@ -188,6 +190,9 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Permission)) RegisterModel(new(GroupPermissions)) RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -208,6 +213,9 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Permission)) RegisterModel(new(GroupPermissions)) RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) BootStrap() @@ -233,6 +241,9 @@ var DataValues = map[string]interface{}{ "Boolean": true, "Char": "char", "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), "Date": time.Now(), "DateTime": time.Now(), "Byte": byte(1<<8 - 1), @@ -257,10 +268,12 @@ func TestDataTypes(t *testing.T) { ind := reflect.Indirect(reflect.ValueOf(&d)) for name, value := range DataValues { + if name == "JSON" { + continue + } e := ind.FieldByName(name) e.Set(reflect.ValueOf(value)) } - id, err := dORM.Insert(&d) throwFail(t, err) throwFail(t, AssertIs(id, 1)) @@ -281,6 +294,9 @@ func TestDataTypes(t *testing.T) { case "DateTime": vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) } throwFail(t, AssertIs(vu == value, true), value, vu) } @@ -299,10 +315,18 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + d = DataNull{ID: 1} err = dORM.Read(&d) throwFail(t, err) + throwFail(t, AssertIs(d.JSON, data)) + throwFail(t, AssertIs(d.NullBool.Valid, false)) throwFail(t, AssertIs(d.NullString.Valid, false)) throwFail(t, AssertIs(d.NullInt64.Valid, false)) @@ -326,6 +350,9 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(d.Float32Ptr, nil)) throwFail(t, AssertIs(d.Float64Ptr, nil)) throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() throwFail(t, err) @@ -352,6 +379,9 @@ func TestNullDataTypes(t *testing.T) { float32Ptr := float32(42.0) float64Ptr := float64(42.0) decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() d = DataNull{ DateTime: time.Now(), @@ -377,6 +407,9 @@ func TestNullDataTypes(t *testing.T) { Float32Ptr: &float32Ptr, Float64Ptr: &float64Ptr, DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, } id, err = dORM.Insert(&d) @@ -417,6 +450,9 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime))) + throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate))) + throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime))) } func TestDataCustomTypes(t *testing.T) { @@ -1521,6 +1557,7 @@ func TestRawQueryRow(t *testing.T) { Boolean bool Char string Text string + Time time.Time Date time.Time DateTime time.Time Byte byte @@ -1549,14 +1586,14 @@ func TestRawQueryRow(t *testing.T) { Q := dDbBaser.TableQuote() cols := []string{ - "id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", } sep := fmt.Sprintf("%s, %s", Q, Q) query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) var id int values := []interface{}{ - &id, &Boolean, &Char, &Text, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, } err := dORM.Raw(query, 1).QueryRow(values...) @@ -1567,6 +1604,10 @@ func TestRawQueryRow(t *testing.T) { switch col { case "id": throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) case "date": v = v.(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc) @@ -1614,6 +1655,9 @@ func TestQueryRows(t *testing.T) { e := ind.FieldByName(name) vu := e.Interface() switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) @@ -1638,6 +1682,9 @@ func TestQueryRows(t *testing.T) { e := ind.FieldByName(name) vu := e.Interface() switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) @@ -1959,3 +2006,171 @@ func TestInLine(t *testing.T) { throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) } + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{&IntegerPk{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, pk, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err = dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "id", + "NO": "no", + "NOO": "noo", + "NOOooOOoo": "noo_oo_oo_oo", + "OrderNO": "order_no", + "tagName": "tag_name", + "tag_Name": "tag_name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} diff --git a/orm/types.go b/orm/types.go index 41933dd1..cb55e71a 100644 --- a/orm/types.go +++ b/orm/types.go @@ -420,4 +420,5 @@ type dbBaser interface { ShowColumnsQuery(string) string IndexExists(dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error } diff --git a/orm/utils.go b/orm/utils.go index 99437c7b..e3cd8ad6 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -181,18 +181,36 @@ func ToInt64(value interface{}) (d int64) { return } -// snake string, XxYy to xx_yy +// snake string, XxYy to xx_yy , XxYY to xx_yy func snakeString(s string) string { data := make([]byte, 0, len(s)*2) - j := false num := len(s) for i := 0; i < num; i++ { d := s[i] - if i > 0 && d >= 'A' && d <= 'Z' && j { - data = append(data, '_') - } - if d != '_' { - j = true + if i > 0 && d != '_' && s[i-1] != '_' { + need := false + // upper as 1, lower as 0 + // XX -> 11 -> 11 + // Xx -> 10 -> 10 + // XxYyZZ -> 101011 -> 10_10_11 + isUpper := d >= 'A' && d <= 'Z' + preIsUpper := s[i-1] >= 'A' && s[i-1] <= 'Z' + if isUpper { + // like : xxYy + if !preIsUpper { + need = true + } + } else { + if preIsUpper { + // ignore "Xy" in "xxXyy" + if i-2 >= 0 && s[i-2] >= 'A' && s[i-2] <= 'Z' { + need = true + } + } + } + if need { + data = append(data, '_') + } } data = append(data, d) } diff --git a/parser.go b/parser.go index 46d02320..caa2b38b 100644 --- a/parser.go +++ b/parser.go @@ -23,10 +23,11 @@ import ( "go/token" "io/ioutil" "os" - "path" + "path/filepath" "sort" "strings" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -55,10 +56,11 @@ func init() { } func parserPkg(pkgRealpath, pkgpath string) error { - rep := strings.NewReplacer("/", "_", ".", "_") - commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go" + rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") + commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) + commentFilename = coomentPrefix + rep.Replace(commentFilename) + ".go" if !compareFile(pkgRealpath) { - Info(pkgRealpath + " no changed") + logs.Info(pkgRealpath + " no changed") return nil } genInfoList = make(map[string][]ControllerComments) @@ -86,7 +88,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { } } } - genRouterCode() + genRouterCode(pkgRealpath) savetoFile(pkgRealpath) return nil } @@ -129,9 +131,9 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat return nil } -func genRouterCode() { - os.Mkdir(path.Join(AppPath, "routers"), 0755) - Info("generate router from comments") +func genRouterCode(pkgRealpath string) { + os.Mkdir(getRouterDir(pkgRealpath), 0755) + logs.Info("generate router from comments") var ( globalinfo string sortKey []string @@ -164,15 +166,15 @@ func genRouterCode() { globalinfo = globalinfo + ` beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.ControllerComments{ - "` + strings.TrimSpace(c.Method) + `", - ` + "`" + c.Router + "`" + `, - ` + allmethod + `, - ` + params + `}) + Method: "` + strings.TrimSpace(c.Method) + `", + ` + "Router: `" + c.Router + "`" + `, + AllowHTTPMethods: ` + allmethod + `, + Params: ` + params + `}) ` } } if globalinfo != "" { - f, err := os.Create(path.Join(AppPath, "routers", commentFilename)) + f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) if err != nil { panic(err) } @@ -182,7 +184,7 @@ func genRouterCode() { } func compareFile(pkgRealpath string) bool { - if !utils.FileExists(path.Join(AppPath, "routers", commentFilename)) { + if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { return true } if utils.FileExists(lastupdateFilename) { @@ -229,3 +231,19 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { } return lastupdate, nil } + +func getRouterDir(pkgRealpath string) string { + dir := filepath.Dir(pkgRealpath) + for { + d := filepath.Join(dir, "routers") + if utils.FileExists(d) { + return d + } + + if r, _ := filepath.Rel(dir, AppPath); r == "." { + return d + } + // Parent dir. + dir = filepath.Dir(dir) + } +} diff --git a/router.go b/router.go index d0bf534f..960cd104 100644 --- a/router.go +++ b/router.go @@ -28,6 +28,7 @@ import ( "time" beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" ) @@ -114,7 +115,7 @@ type controllerInfo struct { type ControllerRegister struct { routers map[string]*Tree enableFilter bool - filters map[int][]*FilterRouter + filters [FinishRouter + 1][]*FilterRouter pool sync.Pool } @@ -122,7 +123,6 @@ type ControllerRegister struct { func NewControllerRegister() *ControllerRegister { cr := &ControllerRegister{ routers: make(map[string]*Tree), - filters: make(map[int][]*FilterRouter), } cr.pool.New = func() interface{} { return beecontext.NewContext() @@ -408,7 +408,6 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // InsertFilter Add a FilterFunc with pattern rule and action constant. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = pattern @@ -426,9 +425,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter } // add Filter into -func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error { - p.filters[pos] = append(p.filters[pos], mr) +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + err = fmt.Errorf("can not find your filter postion") + return + } p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) return nil } @@ -437,11 +440,11 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { paths := strings.Split(endpoint, ".") if len(paths) <= 1 { - Warn("urlfor endpoint must like path.controller.method") + logs.Warn("urlfor endpoint must like path.controller.method") return "" } if len(values)%2 != 0 { - Warn("urlfor params must key-value pair") + logs.Warn("urlfor params must key-value pair") return "" } params := make(map[string]string) @@ -577,20 +580,16 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin return false, "" } -func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { - if p.enableFilter { - if l, ok := p.filters[pos]; ok { - for _, filterR := range l { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - } +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + for _, filterR := range p.filters[pos] { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true } } return false @@ -617,11 +616,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) context.Output.Header("Server", BConfig.ServerName) } - var urlPath string + var urlPath = r.URL.Path + if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(r.URL.Path) - } else { - urlPath = r.URL.Path + urlPath = strings.ToLower(urlPath) } // filter wrong http method @@ -631,11 +629,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // filter for static file - if p.execFilter(context, BeforeStatic, urlPath) { + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { goto Admin } serverStaticRouter(context) + if context.ResponseWriter.Started { findRouter = true goto Admin @@ -653,9 +652,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) var err error context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { - Error(err) + logs.Error(err) exception("503", context) - return + goto Admin } defer func() { if context.Input.CruSession != nil { @@ -663,8 +662,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } }() } - - if p.execFilter(context, BeforeRouter, urlPath) { + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { goto Admin } @@ -693,7 +691,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if findRouter { //execute middleware filters - if p.execFilter(context, BeforeExec, urlPath) { + if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { goto Admin } isRunnable := false @@ -783,7 +781,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !context.ResponseWriter.Started && context.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { - panic(err) + logs.Error(err) } } } @@ -794,17 +792,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } //execute middleware filters - if p.execFilter(context, AfterExec, urlPath) { + if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { goto Admin } } - - p.execFilter(context, FinishRouter, urlPath) + if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + goto Admin + } Admin: - timeDur := time.Since(startTime) //admin module record QPS if BConfig.Listen.EnableAdmin { + timeDur := time.Since(startTime) if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { if runRouter != nil { go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) @@ -815,6 +814,7 @@ Admin: } if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { + timeDur := time.Since(startTime) var devInfo string if findRouter { if routerInfo != nil { @@ -826,7 +826,7 @@ Admin: devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch") } if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) { - Debug(devInfo) + logs.Debug(devInfo) } } @@ -843,27 +843,26 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) { } if !BConfig.RecoverPanic { panic(err) - } else { - if BConfig.EnableErrorsShow { - if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { - exception(fmt.Sprint(err), context) - return - } + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), context) + return } - var stack string - Critical("the request url is ", context.Input.URL()) - Critical("Handler crashed with error", err) - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - Critical(fmt.Sprintf("%s:%d", file, line)) - stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) - } - if BConfig.RunMode == DEV { - showErr(err, context, stack) + } + var stack string + logs.Critical("the request url is ", context.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV { + showErr(err, context, stack) } } } diff --git a/router_test.go b/router_test.go index f26f0c86..9f11286c 100644 --- a/router_test.go +++ b/router_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" ) type TestController struct { @@ -94,7 +95,7 @@ func TestUrlFor(t *testing.T) { handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/person/:last/:first", &TestController{}, "*:Param") if a := handler.URLFor("TestController.List"); a != "/api/list" { - Info(a) + logs.Info(a) t.Errorf("TestController.List must equal to /api/list") } if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { @@ -120,24 +121,24 @@ func TestUrlFor2(t *testing.T) { handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { - Info(handler.URLFor("TestController.GetURL")) + logs.Info(handler.URLFor("TestController.GetURL")) t.Errorf("TestController.List must equal to /v1/astaxie/edit") } if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != "/v1/za/cms_12_123.html" { - Info(handler.URLFor("TestController.List")) + logs.Info(handler.URLFor("TestController.List")) t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") } if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != "/v1/za_cms/ttt_12_123.html" { - Info(handler.URLFor("TestController.Param")) + logs.Info(handler.URLFor("TestController.Param")) t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") } if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", ":title", "aaaa", ":entid", "aaaa") != "/1111/11/aaaa/aaaa" { - Info(handler.URLFor("TestController.Get")) + logs.Info(handler.URLFor("TestController.Get")) t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") } } diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 969d26c9..838ec669 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -115,7 +115,6 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) { } st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", b, time.Now().Unix(), st.sid) - } // Provider mysql session provider diff --git a/session/sess_file.go b/session/sess_file.go index 9265b030..91acfcd4 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -16,7 +16,6 @@ package session import ( "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -82,14 +81,17 @@ func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { b, err := EncodeGob(fs.values) if err != nil { + SLogger.Println(err) return } _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) var f *os.File if err == nil { f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + SLogger.Println(err) } else if os.IsNotExist(err) { f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + SLogger.Println(err) } else { return } @@ -123,7 +125,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var f *os.File @@ -191,7 +193,7 @@ func (fp *FileProvider) SessionAll() int { return a.visit(path, f, err) }) if err != nil { - fmt.Printf("filepath.Walk() returned %v\n", err) + SLogger.Printf("filepath.Walk() returned %v\n", err) return 0 } return a.total @@ -205,11 +207,11 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var newf *os.File diff --git a/session/session.go b/session/session.go index 9fe99a17..73f0d677 100644 --- a/session/session.go +++ b/session/session.go @@ -31,9 +31,14 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" + "io" + "log" "net/http" + "net/textproto" "net/url" + "os" "time" ) @@ -61,6 +66,9 @@ type Provider interface { var provides = make(map[string]Provider) +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + // Register makes a session provide available by the provided name. // If Register is called twice with the same name or if driver is nil, // it panics. @@ -75,15 +83,18 @@ func Register(name string, provide Provider) { } type managerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"` + SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"` + EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"` } // Manager contains Provider and its configuration. @@ -118,6 +129,19 @@ func NewManager(provideName, config string) (*Manager, error) { if cf.Maxlifetime == 0 { cf.Maxlifetime = cf.Gclifetime } + + if cf.EnableSidInHttpHeader { + if cf.SessionNameInHttpHeader == "" { + panic(errors.New("SessionNameInHttpHeader is empty")) + } + + strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader) + if cf.SessionNameInHttpHeader != strMimeHeader { + strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader + panic(errors.New(strErrMsg)) + } + } + err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) if err != nil { return nil, err @@ -143,12 +167,24 @@ func NewManager(provideName, config string) (*Manager, error) { func (manager *Manager) getSid(r *http.Request) (string, error) { cookie, errs := r.Cookie(manager.config.CookieName) if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 { - errs := r.ParseForm() - if errs != nil { - return "", errs + var sid string + if manager.config.EnableSidInUrlQuery { + errs := r.ParseForm() + if errs != nil { + return "", errs + } + + sid = r.FormValue(manager.config.CookieName) + } + + // if not found in Cookie / param, then read it from request headers + if manager.config.EnableSidInHttpHeader && sid == "" { + sids, isFound := r.Header[manager.config.SessionNameInHttpHeader] + if isFound && len(sids) != 0 { + return sids[0], nil + } } - sid := r.FormValue(manager.config.CookieName) return sid, nil } @@ -192,11 +228,21 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } r.AddCookie(cookie) + if manager.config.EnableSidInHttpHeader { + r.Header.Set(manager.config.SessionNameInHttpHeader, sid) + w.Header().Set(manager.config.SessionNameInHttpHeader, sid) + } + return } // SessionDestroy Destroy session by its id in http request cookie. func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + if manager.config.EnableSidInHttpHeader { + r.Header.Del(manager.config.SessionNameInHttpHeader) + w.Header().Del(manager.config.SessionNameInHttpHeader) + } + cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { return @@ -261,6 +307,12 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque http.SetCookie(w, cookie) } r.AddCookie(cookie) + + if manager.config.EnableSidInHttpHeader { + r.Header.Set(manager.config.SessionNameInHttpHeader, sid) + w.Header().Set(manager.config.SessionNameInHttpHeader, sid) + } + return } @@ -296,3 +348,15 @@ func (manager *Manager) isSecure(req *http.Request) bool { } return true } + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + sl := new(Log) + sl.Logger = log.New(out, "[SESSION]", 1e9) + return sl +} diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..4dcf160a --- /dev/null +++ b/session/ssdb/sess_ssdb.go @@ -0,0 +1,192 @@ +package ssdb + +import ( + "errors" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" + "github.com/ssdb/gossdb/ssdb" +) + +var ssdbProvider = &SsdbProvider{} + +type SsdbProvider struct { + client *ssdb.Client + host string + port int + maxLifetime int64 +} + +func (p *SsdbProvider) connectInit() error { + var err error + if p.host == "" || p.port == 0 { + return errors.New("SessionInit First") + } + p.client, err = ssdb.Connect(p.host, p.port) + if err != nil { + return err + } + return nil +} + +func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error { + var e error = nil + p.maxLifetime = maxLifetime + address := strings.Split(savePath, ":") + p.host = address[0] + p.port, e = strconv.Atoi(address[1]) + if e != nil { + return e + } + err := p.connectInit() + if err != nil { + return err + } + return nil +} + +func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + var kv map[interface{}]interface{} + value, err := p.client.Get(sid) + if err != nil { + return nil, err + } + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +func (p *SsdbProvider) SessionExist(sid string) bool { + if p.client == nil { + if err := p.connectInit(); err != nil { + panic(err) + } + } + value, err := p.client.Get(sid) + if err != nil { + panic(err) + } + if value == nil || len(value.(string)) == 0 { + return false + } + return true + +} +func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + //conn.Do("setx", key, v, ttl) + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + value, err := p.client.Get(oldsid) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + _, err = p.client.Del(oldsid) + if err != nil { + return nil, err + } + } + _, e := p.client.Do("setx", sid, value, p.maxLifetime) + if e != nil { + return nil, e + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +func (p *SsdbProvider) SessionDestroy(sid string) error { + if p.client == nil { + if err := p.connectInit(); err != nil { + return err + } + } + _, err := p.client.Del(sid) + if err != nil { + return err + } + return nil +} + +func (p *SsdbProvider) SessionGC() { + return +} + +func (p *SsdbProvider) SessionAll() int { + return 0 +} + +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxLifetime int64 + client *ssdb.Client +} + +func (s *SessionStore) Set(key, value interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + s.values[key] = value + return nil +} +func (s *SessionStore) Get(key interface{}) interface{} { + s.lock.Lock() + defer s.lock.Unlock() + if value, ok := s.values[key]; ok { + return value + } + return nil +} + +func (s *SessionStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.values, key) + return nil +} +func (s *SessionStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + s.values = make(map[interface{}]interface{}) + return nil +} +func (s *SessionStore) SessionID() string { + return s.sid +} + +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(s.values) + if err != nil { + return + } + s.client.Do("setx", s.sid, string(b), s.maxLifetime) + +} +func init() { + session.Register("ssdb", ssdbProvider) +} diff --git a/staticfile.go b/staticfile.go index 4b19f949..11a2cdcc 100644 --- a/staticfile.go +++ b/staticfile.go @@ -27,6 +27,7 @@ import ( "time" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" ) var errNotStaticRequest = errors.New("request not a static file request") @@ -48,14 +49,19 @@ func serverStaticRouter(ctx *context.Context) { if filePath == "" || fileInfo == nil { if BConfig.RunMode == DEV { - Warn("Can't find/open the file:", filePath, err) + logs.Warn("Can't find/open the file:", filePath, err) } http.NotFound(ctx.ResponseWriter, ctx.Request) return } if fileInfo.IsDir() { - //serveFile will list dir - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + requestURL := ctx.Input.URL() + if requestURL[len(requestURL)-1] != '/' { + ctx.Redirect(302, requestURL+"/") + } else { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + } return } @@ -67,7 +73,7 @@ func serverStaticRouter(ctx *context.Context) { b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) if err != nil { if BConfig.RunMode == DEV { - Warn("Can't compress the file:", filePath, err) + logs.Warn("Can't compress the file:", filePath, err) } http.NotFound(ctx.ResponseWriter, ctx.Request) return diff --git a/swagger/docs_spec.go b/swagger/docs_spec.go deleted file mode 100644 index 680324dc..00000000 --- a/swagger/docs_spec.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package swagger struct definition -package swagger - -// SwaggerVersion show the current swagger version -const SwaggerVersion = "1.2" - -// ResourceListing list the resource -type ResourceListing struct { - APIVersion string `json:"apiVersion"` - SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2 - // BasePath string `json:"basePath"` obsolete in 1.1 - APIs []APIRef `json:"apis"` - Info Information `json:"info"` -} - -// APIRef description the api path and description -type APIRef struct { - Path string `json:"path"` // relative or absolute, must start with / - Description string `json:"description"` -} - -// Information show the API Information -type Information struct { - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Contact string `json:"contact,omitempty"` - TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"` - License string `json:"license,omitempty"` - LicenseURL string `json:"licenseUrl,omitempty"` -} - -// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json -type APIDeclaration struct { - APIVersion string `json:"apiVersion"` - SwaggerVersion string `json:"swaggerVersion"` - BasePath string `json:"basePath"` - ResourcePath string `json:"resourcePath"` // must start with / - Consumes []string `json:"consumes,omitempty"` - Produces []string `json:"produces,omitempty"` - APIs []API `json:"apis,omitempty"` - Models map[string]Model `json:"models,omitempty"` -} - -// API show tha API struct -type API struct { - Path string `json:"path"` // relative or absolute, must start with / - Description string `json:"description"` - Operations []Operation `json:"operations,omitempty"` -} - -// Operation desc the Operation -type Operation struct { - HTTPMethod string `json:"httpMethod"` - Nickname string `json:"nickname"` - Type string `json:"type"` // in 1.1 = DataType - // ResponseClass string `json:"responseClass"` obsolete in 1.2 - Summary string `json:"summary,omitempty"` - Notes string `json:"notes,omitempty"` - Parameters []Parameter `json:"parameters,omitempty"` - ResponseMessages []ResponseMessage `json:"responseMessages,omitempty"` // optional - Consumes []string `json:"consumes,omitempty"` - Produces []string `json:"produces,omitempty"` - Authorizations []Authorization `json:"authorizations,omitempty"` - Protocols []Protocol `json:"protocols,omitempty"` -} - -// Protocol support which Protocol -type Protocol struct { -} - -// ResponseMessage Show the -type ResponseMessage struct { - Code int `json:"code"` - Message string `json:"message"` - ResponseModel string `json:"responseModel"` -} - -// Parameter desc the request parameters -type Parameter struct { - ParamType string `json:"paramType"` // path,query,body,header,form - Name string `json:"name"` - Description string `json:"description"` - DataType string `json:"dataType"` // 1.2 needed? - Type string `json:"type"` // integer - Format string `json:"format"` // int64 - AllowMultiple bool `json:"allowMultiple"` - Required bool `json:"required"` - Minimum int `json:"minimum"` - Maximum int `json:"maximum"` -} - -// ErrorResponse desc response -type ErrorResponse struct { - Code int `json:"code"` - Reason string `json:"reason"` -} - -// Model define the data model -type Model struct { - ID string `json:"id"` - Required []string `json:"required,omitempty"` - Properties map[string]ModelProperty `json:"properties"` -} - -// ModelProperty define the properties -type ModelProperty struct { - Type string `json:"type"` - Description string `json:"description"` - Items map[string]string `json:"items,omitempty"` - Format string `json:"format"` -} - -// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations -type Authorization struct { - LocalOAuth OAuth `json:"local-oauth"` - APIKey APIKey `json:"apiKey"` -} - -// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations -type OAuth struct { - Type string `json:"type"` // e.g. oauth2 - Scopes []string `json:"scopes"` // e.g. PUBLIC - GrantTypes map[string]GrantType `json:"grantTypes"` -} - -// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations -type GrantType struct { - LoginEndpoint Endpoint `json:"loginEndpoint"` - TokenName string `json:"tokenName"` // e.g. access_code - TokenRequestEndpoint Endpoint `json:"tokenRequestEndpoint"` - TokenEndpoint Endpoint `json:"tokenEndpoint"` -} - -// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations -type Endpoint struct { - URL string `json:"url"` - ClientIDName string `json:"clientIdName"` - ClientSecretName string `json:"clientSecretName"` - TokenName string `json:"tokenName"` -} - -// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations -type APIKey struct { - Type string `json:"type"` // e.g. apiKey - PassAs string `json:"passAs"` // e.g. header -} diff --git a/swagger/swagger.go b/swagger/swagger.go new file mode 100644 index 00000000..e48dcf1e --- /dev/null +++ b/swagger/swagger.go @@ -0,0 +1,155 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +// Swagger list the resource +type Swagger struct { + SwaggerVersion string `json:"swagger,omitempty"` + Infos Information `json:"info"` + Host string `json:"host,omitempty"` + BasePath string `json:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty"` + Produces []string `json:"produces,omitempty"` + Paths map[string]Item `json:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty"` + SecurityDefinitions map[string]Scurity `json:"securityDefinitions,omitempty"` + Security map[string][]string `json:"security,omitempty"` + Tags []Tag `json:"tags,omitempty"` + ExternalDocs ExternalDocs `json:"externalDocs,omitempty"` +} + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information struct { + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Version string `json:"version,omitempty"` + TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"` + + Contact Contact `json:"contact,omitempty"` + License License `json:"license,omitempty"` +} + +// Contact information for the exposed API. +type Contact struct { + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` + EMail string `json:"email,omitempty"` +} + +// License information for the exposed API. +type License struct { + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` +} + +// Item Describes the operations available on a single path. +type Item struct { + Ref string `json:"$ref,omitempty"` + Get *Operation `json:"get,omitempty"` + Put *Operation `json:"put,omitempty"` + Post *Operation `json:"post,omitempty"` + Delete *Operation `json:"delete,omitempty"` + Options *Operation `json:"options,omitempty"` + Head *Operation `json:"head,omitempty"` + Patch *Operation `json:"patch,omitempty"` +} + +// Operation Describes a single API operation on a path. +type Operation struct { + Tags []string `json:"tags,omitempty"` + Summary string `json:"summary,omitempty"` + Description string `json:"description,omitempty"` + OperationID string `json:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty"` + Produces []string `json:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` +} + +// Parameter Describes a single operation parameter. +type Parameter struct { + In string `json:"in,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` + Schema Schema `json:"schema,omitempty"` + Type string `json:"type,omitempty"` + Format string `json:"format,omitempty"` +} + +// Schema Object allows the definition of input and output data types. +type Schema struct { + Ref string `json:"$ref,omitempty"` + Title string `json:"title,omitempty"` + Format string `json:"format,omitempty"` + Description string `json:"description,omitempty"` + Required []string `json:"required,omitempty"` + Type string `json:"type,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty"` +} + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie struct { + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Default string `json:"default,omitempty"` + Type string `json:"type,omitempty"` + Example string `json:"example,omitempty"` + Required []string `json:"required,omitempty"` + Format string `json:"format,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty"` +} + +// Response as they are returned from executing this operation. +type Response struct { + Description string `json:"description,omitempty"` + Schema Schema `json:"schema,omitempty"` + Ref string `json:"$ref,omitempty"` +} + +// Scurity Allows the definition of a security scheme that can be used by the operations +type Scurity struct { + Type string `json:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". + Description string `json:"description,omitempty"` + Name string `json:"name,omitempty"` + In string `json:"in,omitempty"` // Valid values are "query" or "header". + Flow string `json:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". + AuthorizationURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + Scopes map[string]string `json:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. +} + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + ExternalDocs ExternalDocs `json:"externalDocs,omitempty"` +} + +// ExternalDocs include Additional external documentation +type ExternalDocs struct { + Description string `json:"description,omitempty"` + URL string `json:"url,omitempty"` +} diff --git a/template.go b/template.go index e6c43f87..494acc4f 100644 --- a/template.go +++ b/template.go @@ -26,6 +26,7 @@ import ( "strings" "sync" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -36,19 +37,35 @@ var ( templatesLock sync.RWMutex // beeTemplateExt stores the template extension which will build beeTemplateExt = []string{"tpl", "html"} + // beeTemplatePreprocessors stores associations of extension -> preprocessor handler + beeTemplateEngines = map[string]templatePreProcessor{} ) -func executeTemplate(wr io.Writer, name string, data interface{}) error { +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { if BConfig.RunMode == DEV { templatesLock.RLock() defer templatesLock.RUnlock() } if t, ok := beeTemplates[name]; ok { - err := t.ExecuteTemplate(wr, name, data) - if err != nil { - Trace("template Execute err:", err) + if t.Lookup(name) != nil { + err := t.ExecuteTemplate(wr, name, data) + if err != nil { + logs.Trace("template Execute err:", err) + } + return err + } else { + err := t.Execute(wr, data) + if err != nil { + if err != nil { + logs.Trace("template Execute err:", err) + } + return err + } } - return err + return nil } panic("can't find templatefile in the path:" + name) } @@ -88,6 +105,8 @@ func AddFuncMap(key string, fn interface{}) error { return nil } +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + type templateFile struct { root string files map[string][]string @@ -156,13 +175,22 @@ func BuildTemplate(dir string, files ...string) error { fmt.Printf("filepath.Walk() returned %v\n", err) return err } + buildAllFiles := len(files) == 0 for _, v := range self.files { for _, file := range v { - if len(files) == 0 || utils.InSlice(file, files) { + if buildAllFiles || utils.InSlice(file, files) { templatesLock.Lock() - t, err := getTemplate(self.root, file, v...) + ext := filepath.Ext(file) + var t *template.Template + if len(ext) == 0 { + t, err = getTemplate(self.root, file, v...) + } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { + t, err = fn(self.root, file, beegoTplFuncMap) + } else { + t, err = getTemplate(self.root, file, v...) + } if err != nil { - Trace("parse template err:", file, err) + logs.Trace("parse template err:", file, err) } else { beeTemplates[file] = t } @@ -240,7 +268,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others var subMods1 [][]string t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { - Trace("template parse file err:", err) + logs.Trace("template parse file err:", err) } else if subMods1 != nil && len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } @@ -261,7 +289,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others var subMods1 [][]string t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { - Trace("template parse file err:", err) + logs.Trace("template parse file err:", err) } else if subMods1 != nil && len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } @@ -305,3 +333,9 @@ func DelStaticPath(url string) *App { delete(BConfig.WebConfig.StaticDir, url) return BeeApp } + +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + AddTemplateExt(extension) + beeTemplateEngines[extension] = fn + return BeeApp +} diff --git a/templatefunc.go b/templatefunc.go index 8558733f..36442984 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -421,18 +421,18 @@ func RenderForm(obj interface{}) template.HTML { fieldT := objT.Field(i) - label, name, fType, id, class, ignored := parseFormTag(fieldT) + label, name, fType, id, class, ignored, required := parseFormTag(fieldT) if ignored { continue } - raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class)) + raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) } return template.HTML(strings.Join(raw, "
")) } // renderFormField returns a string containing HTML of a single form field. -func renderFormField(label, name, fType string, value interface{}, id string, class string) string { +func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { if id != "" { id = " id=\"" + id + "\"" } @@ -441,11 +441,16 @@ func renderFormField(label, name, fType string, value interface{}, id string, cl class = " class=\"" + class + "\"" } - if isValidForInput(fType) { - return fmt.Sprintf(`%v`, label, id, class, name, fType, value) + requiredString := "" + if required { + requiredString = " required" } - return fmt.Sprintf(`%v<%v%v%v name="%v">%v`, label, fType, id, class, name, value, fType) + if isValidForInput(fType) { + return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) + } + + return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) } // isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. @@ -461,7 +466,7 @@ func isValidForInput(fType string) bool { // parseFormTag takes the stuct-tag of a StructField and parses the `form` value. // returned are the form label, name-property, type and wether the field should be ignored. -func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool) { +func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { tags := strings.Split(fieldT.Tag.Get("form"), ",") label = fieldT.Name + ": " name = fieldT.Name @@ -470,6 +475,12 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str id = fieldT.Tag.Get("id") class = fieldT.Tag.Get("class") + required = false + required_field := fieldT.Tag.Get("required") + if required_field != "-" && required_field != "" { + required, _ = strconv.ParseBool(required_field) + } + switch len(tags) { case 1: if tags[0] == "-" { @@ -496,6 +507,7 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str label = tags[2] } } + return } diff --git a/templatefunc_test.go b/templatefunc_test.go index 98fbf7ab..86de37ae 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -195,54 +195,78 @@ func TestRenderForm(t *testing.T) { } func TestRenderFormField(t *testing.T) { - html := renderFormField("Label: ", "Name", "text", "Value", "", "") + html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) if html != `Label: ` { t.Errorf("Wrong html output for input[type=text]: %v ", html) } - html = renderFormField("Label: ", "Name", "textarea", "Value", "", "") + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) if html != `Label: ` { t.Errorf("Wrong html output for textarea: %v ", html) } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } } func TestParseFormTag(t *testing.T) { // create struct to contain field with different types of struct-tag `form` type user struct { - All int `form:"name,text,年龄:"` - NoName int `form:",hidden,年龄:"` - OnlyLabel int `form:",,年龄:"` - OnlyName int `form:"name" id:"name" class:"form-name"` - Ignored int `form:"-"` + All int `form:"name,text,年龄:"` + NoName int `form:",hidden,年龄:"` + OnlyLabel int `form:",,年龄:"` + OnlyName int `form:"name" id:"name" class:"form-name"` + Ignored int `form:"-"` + Required int `form:"name" required:"true"` + IgnoreRequired int `form:"name"` + NotRequired int `form:"name" required:"false"` } objT := reflect.TypeOf(&user{}).Elem() - label, name, fType, id, class, ignored := parseFormTag(objT.Field(0)) + label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0)) if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) { t.Errorf("Form Tag with name, label and type was not correctly parsed.") } - label, name, fType, id, class, ignored = parseFormTag(objT.Field(1)) + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1)) if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) { t.Errorf("Form Tag with label and type but without name was not correctly parsed.") } - label, name, fType, id, class, ignored = parseFormTag(objT.Field(2)) + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2)) if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) { t.Errorf("Form Tag containing only label was not correctly parsed.") } - label, name, fType, id, class, ignored = parseFormTag(objT.Field(3)) + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3)) if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false && id == "name" && class == "form-name") { t.Errorf("Form Tag containing only name was not correctly parsed.") } - label, name, fType, id, class, ignored = parseFormTag(objT.Field(4)) + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4)) if ignored == false { t.Errorf("Form Tag that should be ignored was not correctly parsed.") } + + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5)) + if !(name == "name" && required == true) { + t.Errorf("Form Tag containing only name and required was not correctly parsed.") + } + + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6)) + if !(name == "name" && required == false) { + t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") + } + + label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7)) + if !(name == "name" && required == false) { + t.Errorf("Form Tag containing only name and not required was not correctly parsed.") + } + } func TestMapGet(t *testing.T) { diff --git a/toolbox/task.go b/toolbox/task.go index 537de428..abd411c8 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -389,6 +389,10 @@ func dayMatches(s *Schedule, t time.Time) bool { // StartTask start all tasks func StartTask() { + if isstart { + //If already started, no need to start another goroutine. + return + } isstart = true go run() } @@ -432,8 +436,11 @@ func run() { // StopTask stop all tasks func StopTask() { - isstart = false - stop <- true + if isstart { + isstart = false + stop <- true + } + } // AddTask add task with name diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go index 1a4a6edc..42ac70d3 100644 --- a/utils/captcha/captcha.go +++ b/utils/captcha/captcha.go @@ -69,6 +69,7 @@ import ( "github.com/astaxie/beego" "github.com/astaxie/beego/cache" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -139,7 +140,7 @@ func (c *Captcha) Handler(ctx *context.Context) { if err := c.store.Put(key, chars, c.Expiration); err != nil { ctx.Output.SetStatus(500) ctx.WriteString("captcha reload error") - beego.Error("Reload Create Captcha Error:", err) + logs.Error("Reload Create Captcha Error:", err) return } } else { @@ -154,7 +155,7 @@ func (c *Captcha) Handler(ctx *context.Context) { img := NewImage(chars, c.StdWidth, c.StdHeight) if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { - beego.Error("Write Captcha Image Error:", err) + logs.Error("Write Captcha Image Error:", err) } } @@ -162,7 +163,7 @@ func (c *Captcha) Handler(ctx *context.Context) { func (c *Captcha) CreateCaptchaHTML() template.HTML { value, err := c.CreateCaptcha() if err != nil { - beego.Error("Create Captcha Error:", err) + logs.Error("Create Captcha Error:", err) return "" } diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go index 1d99cac5..2f022d0c 100644 --- a/utils/pagination/controller.go +++ b/utils/pagination/controller.go @@ -18,7 +18,7 @@ import ( "github.com/astaxie/beego/context" ) -// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data["paginator"]. +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { paginator = NewPaginator(context.Request, per, nums) context.Input.SetData("paginator", &paginator) diff --git a/utils/rand.go b/utils/rand.go index 74bb4121..344d1cd5 100644 --- a/utils/rand.go +++ b/utils/rand.go @@ -20,28 +20,24 @@ import ( "time" ) +var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) + // RandomCreateBytes generate random []byte by specify chars. func RandomCreateBytes(n int, alphabets ...byte) []byte { - const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + if len(alphabets) == 0 { + alphabets = alphaNum + } var bytes = make([]byte, n) - var randby bool + var randBy bool if num, err := rand.Read(bytes); num != n || err != nil { r.Seed(time.Now().UnixNano()) - randby = true + randBy = true } for i, b := range bytes { - if len(alphabets) == 0 { - if randby { - bytes[i] = alphanum[r.Intn(len(alphanum))] - } else { - bytes[i] = alphanum[b%byte(len(alphanum))] - } + if randBy { + bytes[i] = alphabets[r.Intn(len(alphabets))] } else { - if randby { - bytes[i] = alphabets[r.Intn(len(alphabets))] - } else { - bytes[i] = alphabets[b%byte(len(alphabets))] - } + bytes[i] = alphabets[b%byte(len(alphabets))] } } return bytes diff --git a/docs.go b/utils/rand_test.go similarity index 50% rename from docs.go rename to utils/rand_test.go index 72532876..6c238b5e 100644 --- a/docs.go +++ b/utils/rand_test.go @@ -1,4 +1,4 @@ -// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2016 beego Author. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,28 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package utils -import ( - "github.com/astaxie/beego/context" -) +import "testing" -// GlobalDocAPI store the swagger api documents -var GlobalDocAPI = make(map[string]interface{}) +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) -func serverDocs(ctx *context.Context) { - var obj interface{} - if splat := ctx.Input.Param(":splat"); splat == "" { - obj = GlobalDocAPI["Root"] - } else { - if v, ok := GlobalDocAPI[splat]; ok { - obj = v - } + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() } - if obj != nil { - ctx.Output.Header("Access-Control-Allow-Origin", "*") - ctx.Output.JSON(obj, false, false) - return + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() } - ctx.Output.SetStatus(404) } diff --git a/validation/validation.go b/validation/validation.go index 2b020aa8..489dfa5e 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -73,6 +73,10 @@ func (e *Error) String() string { return e.Message } +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + // Result is returned from every validation method. // It provides an indication of success, and a pointer to the Error (if any). type Result struct {