diff --git a/.travis.yml b/.travis.yml
index 3c821dcd..93536488 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,8 +1,7 @@
language: go
go:
- - tip
- - 1.6.0
+ - 1.6
- 1.5.3
- 1.4.3
services:
@@ -31,21 +30,20 @@ install:
- go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis
- - go get golang.org/x/tools/cmd/vet
- - go get github.com/golang/lint/golint
- go get github.com/ssdb/gossdb/ssdb
before_script:
+ - psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
+ - sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi"
+ - sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi"
- mkdir -p res/var
- ./ssdb/ssdb-server ./ssdb/ssdb.conf -d
after_script:
-killall -w ssdb-server
- rm -rf ./res/var/*
script:
- - go vet -x ./...
- - $HOME/gopath/bin/golint ./...
- go test -v ./...
-notifications:
- webhooks: https://hooks.pubu.im/services/z7m9bvybl3rgtg9
+addons:
+ postgresql: "9.4"
diff --git a/README.md b/README.md
index 6c589584..fbd7ccb7 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ func main(){
```
######Congratulations!
You just built your first beego app.
-Open your browser and visit `http://localhost:8000`.
+Open your browser and visit `http://localhost:8080`.
Please see [Documentation](http://beego.me/docs) for more.
## Features
diff --git a/admin.go b/admin.go
index 031e6421..a2b2f53a 100644
--- a/admin.go
+++ b/admin.go
@@ -23,7 +23,10 @@ import (
"text/template"
"time"
+ "reflect"
+
"github.com/astaxie/beego/grace"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
)
@@ -90,57 +93,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
switch command {
case "conf":
m := make(map[string]interface{})
+ list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath
m["AppConfigProvider"] = appConfigProvider
- m["BConfig.AppName"] = BConfig.AppName
- m["BConfig.RunMode"] = BConfig.RunMode
- m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
- m["BConfig.ServerName"] = BConfig.ServerName
- m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
- m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
- m["BConfig.EnableGzip"] = BConfig.EnableGzip
- m["BConfig.MaxMemory"] = BConfig.MaxMemory
- m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
- m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
- m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
- m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
- m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
- m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
- m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
- m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
- m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
- m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
- m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
- m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
- m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
- m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
- m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
- m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
- m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
- m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
- m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
- m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
- m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
- m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
- m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
- m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
- m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
- m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
- m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
- m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
- m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
- m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
- m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
- m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
- m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
- m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
- m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
- m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
- m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
- m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
- m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
- m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
- m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
@@ -196,7 +151,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
BeforeExec: "Before Exec",
AfterExec: "After Exec",
FinishRouter: "Finish Router"} {
- if bf, ok := BeeApp.Handlers.filters[k]; ok {
+ if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 {
filterType = fr
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
@@ -223,6 +178,28 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
}
}
+func list(root string, p interface{}, m map[string]interface{}) {
+ pt := reflect.TypeOf(p)
+ pv := reflect.ValueOf(p)
+ if pt.Kind() == reflect.Ptr {
+ pt = pt.Elem()
+ pv = pv.Elem()
+ }
+ for i := 0; i < pv.NumField(); i++ {
+ var key string
+ if root == "" {
+ key = pt.Field(i).Name
+ } else {
+ key = root + "." + pt.Field(i).Name
+ }
+ if pv.Field(i).Kind() == reflect.Struct {
+ list(key, pv.Field(i).Interface(), m)
+ } else {
+ m[key] = pv.Field(i).Interface()
+ }
+ }
+}
+
func printTree(resultList *[][]string, t *Tree) {
for _, tr := range t.fixrouters {
printTree(resultList, tr)
@@ -410,7 +387,7 @@ func (admin *adminApp) Run() {
for p, f := range admin.routers {
http.Handle(p, f)
}
- BeeLogger.Info("Admin server Running on %s", addr)
+ logs.Info("Admin server Running on %s", addr)
var err error
if BConfig.Listen.Graceful {
@@ -419,6 +396,6 @@ func (admin *adminApp) Run() {
err = http.ListenAndServe(addr, nil)
}
if err != nil {
- BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
+ logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
}
}
diff --git a/admin_test.go b/admin_test.go
new file mode 100644
index 00000000..04744f8c
--- /dev/null
+++ b/admin_test.go
@@ -0,0 +1,72 @@
+package beego
+
+import (
+ "testing"
+ "fmt"
+)
+
+func TestList_01(t *testing.T) {
+ m := make(map[string]interface{})
+ list("BConfig", BConfig, m)
+ t.Log(m)
+ om := oldMap()
+ for k, v := range om {
+ if fmt.Sprint(m[k])!= fmt.Sprint(v) {
+ t.Log(k, "old-key",v,"new-key", m[k])
+ t.FailNow()
+ }
+ }
+}
+
+func oldMap() map[string]interface{} {
+ m := make(map[string]interface{})
+ m["BConfig.AppName"] = BConfig.AppName
+ m["BConfig.RunMode"] = BConfig.RunMode
+ m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
+ m["BConfig.ServerName"] = BConfig.ServerName
+ m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
+ m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
+ m["BConfig.EnableGzip"] = BConfig.EnableGzip
+ m["BConfig.MaxMemory"] = BConfig.MaxMemory
+ m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
+ m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
+ m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
+ m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
+ m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
+ m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
+ m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
+ m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
+ m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
+ m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
+ m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
+ m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
+ m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
+ m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
+ m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
+ m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
+ m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
+ m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
+ m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
+ m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
+ m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
+ m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
+ m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
+ m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
+ m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
+ m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
+ m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
+ m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
+ m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
+ m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
+ m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
+ m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
+ m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
+ m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
+ m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
+ m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
+ m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
+ m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
+ m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
+ m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
+ return m
+}
diff --git a/app.go b/app.go
index af54ea4b..423a0a6b 100644
--- a/app.go
+++ b/app.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/astaxie/beego/grace"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -68,9 +69,9 @@ func (app *App) Run() {
if BConfig.Listen.EnableFcgi {
if BConfig.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
- BeeLogger.Info("Use FCGI via standard I/O")
+ logs.Info("Use FCGI via standard I/O")
} else {
- BeeLogger.Critical("Cannot use FCGI via standard I/O", err)
+ logs.Critical("Cannot use FCGI via standard I/O", err)
}
return
}
@@ -84,10 +85,10 @@ func (app *App) Run() {
l, err = net.Listen("tcp", addr)
}
if err != nil {
- BeeLogger.Critical("Listen: ", err)
+ logs.Critical("Listen: ", err)
}
if err = fcgi.Serve(l, app.Handlers); err != nil {
- BeeLogger.Critical("fcgi.Serve: ", err)
+ logs.Critical("fcgi.Serve: ", err)
}
return
}
@@ -95,6 +96,7 @@ func (app *App) Run() {
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
+ app.Server.ErrorLog = logs.GetLogger("HTTP")
// run graceful mode
if BConfig.Listen.Graceful {
@@ -111,7 +113,7 @@ func (app *App) Run() {
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
- BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
+ logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
@@ -126,7 +128,7 @@ func (app *App) Run() {
server.Network = "tcp4"
}
if err := server.ListenAndServe(); err != nil {
- BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
+ logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
@@ -137,16 +139,18 @@ func (app *App) Run() {
}
// run normal mode
- app.Server.Addr = addr
if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
+ } else if BConfig.Listen.EnableHTTP {
+ BeeLogger.Info("Start https server error, confict with http.Please reset https port")
+ return
}
- BeeLogger.Info("https server Running on %s", app.Server.Addr)
+ logs.Info("https server Running on %s", app.Server.Addr)
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
- BeeLogger.Critical("ListenAndServeTLS: ", err)
+ logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
@@ -155,24 +159,24 @@ func (app *App) Run() {
if BConfig.Listen.EnableHTTP {
go func() {
app.Server.Addr = addr
- BeeLogger.Info("http server Running on %s", app.Server.Addr)
+ logs.Info("http server Running on %s", app.Server.Addr)
if BConfig.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
+ logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
if err = app.Server.Serve(ln); err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
+ logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
} else {
if err := app.Server.ListenAndServe(); err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
+ logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
diff --git a/beego.go b/beego.go
index 8f82cdcf..32b64f75 100644
--- a/beego.go
+++ b/beego.go
@@ -51,6 +51,7 @@ func AddAPPStartHook(hf hookfunc) {
// beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
+
initBeforeHTTPRun()
if len(params) > 0 && params[0] != "" {
@@ -71,9 +72,9 @@ func initBeforeHTTPRun() {
AddAPPStartHook(registerMime)
AddAPPStartHook(registerDefaultErrorHandler)
AddAPPStartHook(registerSession)
- AddAPPStartHook(registerDocs)
AddAPPStartHook(registerTemplate)
AddAPPStartHook(registerAdmin)
+ AddAPPStartHook(registerGzip)
for _, hk := range hooks {
if err := hk(); err != nil {
@@ -84,8 +85,11 @@ func initBeforeHTTPRun() {
// TestBeegoInit is for test package init
func TestBeegoInit(ap string) {
- os.Setenv("BEEGO_RUNMODE", "test")
appConfigPath = filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap)
+ if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil {
+ panic(err)
+ }
+ BConfig.RunMode = "test"
initBeforeHTTPRun()
}
diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go
index 47c5acc6..6b81da4d 100644
--- a/cache/redis/redis_test.go
+++ b/cache/redis/redis_test.go
@@ -18,9 +18,8 @@ import (
"testing"
"time"
- "github.com/garyburd/redigo/redis"
-
"github.com/astaxie/beego/cache"
+ "github.com/garyburd/redigo/redis"
)
func TestRedisCache(t *testing.T) {
diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go
index e03ba343..dd474960 100644
--- a/cache/ssdb/ssdb_test.go
+++ b/cache/ssdb/ssdb_test.go
@@ -1,10 +1,11 @@
package ssdb
import (
- "github.com/astaxie/beego/cache"
"strconv"
"testing"
"time"
+
+ "github.com/astaxie/beego/cache"
)
func TestSsdbcacheCache(t *testing.T) {
diff --git a/config.go b/config.go
index 067f0d1b..effc5e12 100644
--- a/config.go
+++ b/config.go
@@ -18,9 +18,11 @@ import (
"fmt"
"os"
"path/filepath"
+ "reflect"
"strings"
"github.com/astaxie/beego/config"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
@@ -81,14 +83,17 @@ type WebConfig struct {
// SessionConfig holds session related config
type SessionConfig struct {
- SessionOn bool
- SessionProvider string
- SessionName string
- SessionGCMaxLifetime int64
- SessionProviderConfig string
- SessionCookieLifeTime int
- SessionAutoSetCookie bool
- SessionDomain string
+ SessionOn bool
+ SessionProvider string
+ SessionName string
+ SessionGCMaxLifetime int64
+ SessionProviderConfig string
+ SessionCookieLifeTime int
+ SessionAutoSetCookie bool
+ SessionDomain string
+ EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers
+ SessionNameInHttpHeader string
+ EnableSidInUrlQuery bool // enable get the sessionId from Url Query params
}
// LogConfig holds Log related config
@@ -115,11 +120,30 @@ var (
)
func init() {
- AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
+ BConfig = newBConfig()
+ var err error
+ if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil {
+ panic(err)
+ }
+ workPath, err := os.Getwd()
+ if err != nil {
+ panic(err)
+ }
+ appConfigPath = filepath.Join(workPath, "conf", "app.conf")
+ if !utils.FileExists(appConfigPath) {
+ appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
+ if !utils.FileExists(appConfigPath) {
+ AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
+ return
+ }
+ }
+ if err = parseConfig(appConfigPath); err != nil {
+ panic(err)
+ }
+}
- os.Chdir(AppPath)
-
- BConfig = &Config{
+func newBConfig() *Config {
+ return &Config{
AppName: "beego",
RunMode: DEV,
RouterCaseSensitive: true,
@@ -162,14 +186,17 @@ func init() {
XSRFKey: "beegoxsrf",
XSRFExpire: 0,
Session: SessionConfig{
- SessionOn: false,
- SessionProvider: "memory",
- SessionName: "beegosessionID",
- SessionGCMaxLifetime: 3600,
- SessionProviderConfig: "",
- SessionCookieLifeTime: 0, //set cookie default is the browser life
- SessionAutoSetCookie: true,
- SessionDomain: "",
+ SessionOn: false,
+ SessionProvider: "memory",
+ SessionName: "beegosessionID",
+ SessionGCMaxLifetime: 3600,
+ SessionProviderConfig: "",
+ SessionCookieLifeTime: 0, //set cookie default is the browser life
+ SessionAutoSetCookie: true,
+ SessionDomain: "",
+ EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers
+ SessionNameInHttpHeader: "Beegosessionid",
+ EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params
},
},
Log: LogConfig{
@@ -178,16 +205,6 @@ func init() {
Outputs: map[string]string{"console": ""},
},
}
-
- appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
- if !utils.FileExists(appConfigPath) {
- AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
- return
- }
-
- if err := parseConfig(appConfigPath); err != nil {
- panic(err)
- }
}
// now only support ini, next will support json.
@@ -196,63 +213,23 @@ func parseConfig(appConfigPath string) (err error) {
if err != nil {
return err
}
+ return assignConfig(AppConfig)
+}
+
+func assignConfig(ac config.Configer) error {
// set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode
- } else if runMode := AppConfig.String("RunMode"); runMode != "" {
+ } else if runMode := ac.String("RunMode"); runMode != "" {
BConfig.RunMode = runMode
}
- BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName)
- BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic)
- BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
- BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
- BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
- BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
- BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
- BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
- BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
- BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
- BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
- BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
- BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
- BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
- BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
- BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
- BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
- BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
- BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
- BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
- BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
- BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
- BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
- BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
- BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
- BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
- BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
- BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
- BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
- BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
- BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
- BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
- BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
- BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
- BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
- BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
- BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
- BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
- BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
- BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
- BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
- BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
- BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
- BConfig.Log.AccessLogs = AppConfig.DefaultBool("LogAccessLogs", BConfig.Log.AccessLogs)
- BConfig.Log.FileLineNum = AppConfig.DefaultBool("LogFileLineNum", BConfig.Log.FileLineNum)
+ for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
+ assignSingleConfig(i, ac)
+ }
- if sd := AppConfig.String("StaticDir"); sd != "" {
- for k := range BConfig.WebConfig.StaticDir {
- delete(BConfig.WebConfig.StaticDir, k)
- }
+ if sd := ac.String("StaticDir"); sd != "" {
+ BConfig.WebConfig.StaticDir = map[string]string{}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
@@ -262,7 +239,8 @@ func parseConfig(appConfigPath string) (err error) {
}
}
}
- if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
+
+ if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",")
fileExts := []string{}
for _, ext := range extensions {
@@ -280,7 +258,7 @@ func parseConfig(appConfigPath string) (err error) {
}
}
- if lo := AppConfig.String("LogOutputs"); lo != "" {
+ if lo := ac.String("LogOutputs"); lo != "" {
los := strings.Split(lo, ";")
for _, v := range los {
if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 {
@@ -292,18 +270,50 @@ func parseConfig(appConfigPath string) (err error) {
}
//init log
- BeeLogger.Reset()
+ logs.Reset()
for adaptor, config := range BConfig.Log.Outputs {
- err = BeeLogger.SetLogger(adaptor, config)
+ err := logs.SetLogger(adaptor, config)
if err != nil {
- fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err)
+ fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error()))
}
}
- SetLogFuncCall(BConfig.Log.FileLineNum)
+ logs.SetLogFuncCall(BConfig.Log.FileLineNum)
return nil
}
+func assignSingleConfig(p interface{}, ac config.Configer) {
+ pt := reflect.TypeOf(p)
+ if pt.Kind() != reflect.Ptr {
+ return
+ }
+ pt = pt.Elem()
+ if pt.Kind() != reflect.Struct {
+ return
+ }
+ pv := reflect.ValueOf(p).Elem()
+
+ for i := 0; i < pt.NumField(); i++ {
+ pf := pv.Field(i)
+ if !pf.CanSet() {
+ continue
+ }
+ name := pt.Field(i).Name
+ switch pf.Kind() {
+ case reflect.String:
+ pf.SetString(ac.DefaultString(name, pf.String()))
+ case reflect.Int, reflect.Int64:
+ pf.SetInt(int64(ac.DefaultInt64(name, pf.Int())))
+ case reflect.Bool:
+ pf.SetBool(ac.DefaultBool(name, pf.Bool()))
+ case reflect.Struct:
+ default:
+ //do nothing here
+ }
+ }
+
+}
+
// LoadAppConfig allow developer to apply a config file
func LoadAppConfig(adapterName, configPath string) error {
absConfigPath, err := filepath.Abs(configPath)
@@ -315,10 +325,6 @@ func LoadAppConfig(adapterName, configPath string) error {
return fmt.Errorf("the target config file: %s don't exist", configPath)
}
- if absConfigPath == appConfigPath {
- return nil
- }
-
appConfigPath = absConfigPath
appConfigProvider = adapterName
@@ -352,7 +358,7 @@ func (b *beegoAppConfig) String(key string) string {
}
func (b *beegoAppConfig) Strings(key string) []string {
- if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" {
+ if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 {
return v
}
return b.innerConfig.Strings(key)
diff --git a/config/config.go b/config/config.go
index c0afec05..9f41fb79 100644
--- a/config/config.go
+++ b/config/config.go
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package config is used to parse config
+// Package config is used to parse config.
// Usage:
-// import(
-// "github.com/astaxie/beego/config"
-// )
+// import "github.com/astaxie/beego/config"
+//Examples.
//
// cnf, err := config.NewConfig("ini", "config.conf")
//
@@ -38,12 +37,12 @@
// cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error
-//
-// more docs http://beego.me/docs/module/config.md
+//More docs http://beego.me/docs/module/config.md
package config
import (
"fmt"
+ "os"
)
// Configer defines how to get and set value from configuration raw data.
@@ -107,6 +106,69 @@ func NewConfigData(adapterName string, data []byte) (Configer, error) {
return adapter.ParseData(data)
}
+// ExpandValueEnvForMap convert all string value with environment variable.
+func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} {
+ for k, v := range m {
+ switch value := v.(type) {
+ case string:
+ m[k] = ExpandValueEnv(value)
+ case map[string]interface{}:
+ m[k] = ExpandValueEnvForMap(value)
+ case map[string]string:
+ for k2, v2 := range value {
+ value[k2] = ExpandValueEnv(v2)
+ }
+ m[k] = value
+ }
+ }
+ return m
+}
+
+// ExpandValueEnv returns value of convert with environment variable.
+//
+// Return environment variable if value start with "${" and end with "}".
+// Return default value if environment variable is empty or not exist.
+//
+// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue".
+// Examples:
+// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable.
+// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/".
+// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie".
+func ExpandValueEnv(value string) (realValue string) {
+ realValue = value
+
+ vLen := len(value)
+ // 3 = ${}
+ if vLen < 3 {
+ return
+ }
+ // Need start with "${" and end with "}", then return.
+ if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' {
+ return
+ }
+
+ key := ""
+ defalutV := ""
+ // value start with "${"
+ for i := 2; i < vLen; i++ {
+ if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
+ key = value[2:i]
+ defalutV = value[i+2 : vLen-1] // other string is default value.
+ break
+ } else if value[i] == '}' {
+ key = value[2:i]
+ break
+ }
+ }
+
+ realValue = os.Getenv(key)
+ if realValue == "" {
+ realValue = defalutV
+ }
+
+ return
+}
+
// ParseBool returns the boolean value represented by the string.
//
// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On,
diff --git a/config/config_test.go b/config/config_test.go
new file mode 100644
index 00000000..15d6ffa6
--- /dev/null
+++ b/config/config_test.go
@@ -0,0 +1,55 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package config
+
+import (
+ "os"
+ "testing"
+)
+
+func TestExpandValueEnv(t *testing.T) {
+
+ testCases := []struct {
+ item string
+ want string
+ }{
+ {"", ""},
+ {"$", "$"},
+ {"{", "{"},
+ {"{}", "{}"},
+ {"${}", ""},
+ {"${|}", ""},
+ {"${}", ""},
+ {"${{}}", ""},
+ {"${{||}}", "}"},
+ {"${pwd||}", ""},
+ {"${pwd||}", ""},
+ {"${pwd||}", ""},
+ {"${pwd||}}", "}"},
+ {"${pwd||{{||}}}", "{{||}}"},
+ {"${GOPATH}", os.Getenv("GOPATH")},
+ {"${GOPATH||}", os.Getenv("GOPATH")},
+ {"${GOPATH||root}", os.Getenv("GOPATH")},
+ {"${GOPATH_NOT||root}", "root"},
+ {"${GOPATH_NOT||||root}", "||root"},
+ }
+
+ for _, c := range testCases {
+ if got := ExpandValueEnv(c.item); got != c.want {
+ t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got)
+ }
+ }
+
+}
diff --git a/config/fake.go b/config/fake.go
index 7e362608..f5144598 100644
--- a/config/fake.go
+++ b/config/fake.go
@@ -38,7 +38,7 @@ func (c *fakeConfigContainer) String(key string) string {
}
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
- v := c.getData(key)
+ v := c.String(key)
if v == "" {
return defaultval
}
@@ -46,7 +46,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin
}
func (c *fakeConfigContainer) Strings(key string) []string {
- v := c.getData(key)
+ v := c.String(key)
if v == "" {
return nil
}
diff --git a/config/ini.go b/config/ini.go
index 9c19b9b1..53bd992d 100644
--- a/config/ini.go
+++ b/config/ini.go
@@ -82,6 +82,10 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
if err == io.EOF {
break
}
+ //It might be a good idea to throw a error on all unknonw errors?
+ if _, ok := err.(*os.PathError); ok {
+ return nil, err
+ }
if bytes.Equal(line, bEmpty) {
continue
}
@@ -162,7 +166,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
val = bytes.Trim(val, `"`)
}
- cfg.data[section][key] = string(val)
+ cfg.data[section][key] = ExpandValueEnv(string(val))
if comment.Len() > 0 {
cfg.keyComment[section+"."+key] = comment.String()
comment.Reset()
@@ -296,7 +300,9 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
return nil, errors.New("not exist setction")
}
-// SaveConfigFile save the config into file
+// SaveConfigFile save the config into file.
+//
+// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation.
func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
diff --git a/config/ini_test.go b/config/ini_test.go
index 93fce61f..83ff3668 100644
--- a/config/ini_test.go
+++ b/config/ini_test.go
@@ -42,11 +42,14 @@ needlogin = ON
enableSession = Y
enableCookie = N
flag = 1
+path1 = ${GOPATH}
+path2 = ${GOPATH||/home/go}
[demo]
key1="asta"
key2 = "xie"
CaseInsensitive = true
peers = one;two;three
+password = ${GOPATH}
`
keyValue = map[string]interface{}{
@@ -64,10 +67,13 @@ peers = one;two;three
"enableSession": true,
"enableCookie": false,
"flag": true,
+ "path1": os.Getenv("GOPATH"),
+ "path2": os.Getenv("GOPATH"),
"demo::key1": "asta",
"demo::key2": "xie",
"demo::CaseInsensitive": true,
"demo::peers": []string{"one", "two", "three"},
+ "demo::password": os.Getenv("GOPATH"),
"null": "",
"demo2::key1": "",
"error": "",
diff --git a/config/json.go b/config/json.go
index fce517eb..a0d93210 100644
--- a/config/json.go
+++ b/config/json.go
@@ -57,6 +57,9 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
}
x.data["rootArray"] = wrappingArray
}
+
+ x.data = ExpandValueEnvForMap(x.data)
+
return x, nil
}
diff --git a/config/json_test.go b/config/json_test.go
index df663461..24ff9644 100644
--- a/config/json_test.go
+++ b/config/json_test.go
@@ -86,16 +86,19 @@ func TestJson(t *testing.T) {
"enableSession": "Y",
"enableCookie": "N",
"flag": 1,
+"path1": "${GOPATH}",
+"path2": "${GOPATH||/home/go}",
"database": {
"host": "host",
"port": "port",
"database": "database",
"username": "username",
- "password": "password",
+ "password": "${GOPATH}",
"conns":{
"maxconnection":12,
"autoconnect":true,
- "connectioninfo":"info"
+ "connectioninfo":"info",
+ "root": "${GOPATH}"
}
}
}`
@@ -115,13 +118,16 @@ func TestJson(t *testing.T) {
"enableSession": true,
"enableCookie": false,
"flag": true,
+ "path1": os.Getenv("GOPATH"),
+ "path2": os.Getenv("GOPATH"),
"database::host": "host",
"database::port": "port",
"database::database": "database",
- "database::password": "password",
+ "database::password": os.Getenv("GOPATH"),
"database::conns::maxconnection": 12,
"database::conns::autoconnect": true,
"database::conns::connectioninfo": "info",
+ "database::conns::root": os.Getenv("GOPATH"),
"unknown": "",
}
)
diff --git a/config/xml/xml.go b/config/xml/xml.go
index b5291bf4..0c4e4d27 100644
--- a/config/xml/xml.go
+++ b/config/xml/xml.go
@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package xml for config provider
+// Package xml for config provider.
//
-// depend on github.com/beego/x2j
+// depend on github.com/beego/x2j.
//
-// go install github.com/beego/x2j
+// go install github.com/beego/x2j.
//
// Usage:
-// import(
-// _ "github.com/astaxie/beego/config/xml"
-// "github.com/astaxie/beego/config"
-// )
+// import(
+// _ "github.com/astaxie/beego/config/xml"
+// "github.com/astaxie/beego/config"
+// )
//
// cnf, err := config.NewConfig("xml", "config.xml")
//
-// more docs http://beego.me/docs/module/config.md
+//More docs http://beego.me/docs/module/config.md
package xml
import (
@@ -69,7 +69,7 @@ func (xc *Config) Parse(filename string) (config.Configer, error) {
return nil, err
}
- x.data = d["config"].(map[string]interface{})
+ x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil
}
@@ -92,7 +92,7 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) {
- if v, ok := c.data[key]; ok {
+ if v := c.data[key]; v != nil {
return config.ParseBool(v)
}
return false, fmt.Errorf("not exist key: %q", key)
diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go
index 60dcba54..d8a09a59 100644
--- a/config/xml/xml_test.go
+++ b/config/xml/xml_test.go
@@ -15,14 +15,18 @@
package xml
import (
+ "fmt"
"os"
"testing"
"github.com/astaxie/beego/config"
)
-//xml parse should incluce in tags
-var xmlcontext = `
+func TestXML(t *testing.T) {
+
+ var (
+ //xml parse should incluce in tags
+ xmlcontext = `
beeapi8080
@@ -31,10 +35,25 @@ var xmlcontext = `
devfalsetrue
+${GOPATH}
+${GOPATH||/home/go}
`
+ keyValue = map[string]interface{}{
+ "appname": "beeapi",
+ "httpport": 8080,
+ "mysqlport": int64(3600),
+ "PI": 3.1415976,
+ "runmode": "dev",
+ "autorender": false,
+ "copyrequestbody": true,
+ "path1": os.Getenv("GOPATH"),
+ "path2": os.Getenv("GOPATH"),
+ "error": "",
+ "emptystrings": []string{},
+ }
+ )
-func TestXML(t *testing.T) {
f, err := os.Create("testxml.conf")
if err != nil {
t.Fatal(err)
@@ -50,39 +69,42 @@ func TestXML(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if xmlconf.String("appname") != "beeapi" {
- t.Fatal("appname not equal to beeapi")
- }
- if port, err := xmlconf.Int("httpport"); err != nil || port != 8080 {
- t.Error(port)
- t.Fatal(err)
- }
- if port, err := xmlconf.Int64("mysqlport"); err != nil || port != 3600 {
- t.Error(port)
- t.Fatal(err)
- }
- if pi, err := xmlconf.Float("PI"); err != nil || pi != 3.1415976 {
- t.Error(pi)
- t.Fatal(err)
- }
- if xmlconf.String("runmode") != "dev" {
- t.Fatal("runmode not equal to dev")
- }
- if v, err := xmlconf.Bool("autorender"); err != nil || v != false {
- t.Error(v)
- t.Fatal(err)
- }
- if v, err := xmlconf.Bool("copyrequestbody"); err != nil || v != true {
- t.Error(v)
- t.Fatal(err)
+
+ for k, v := range keyValue {
+
+ var (
+ value interface{}
+ err error
+ )
+
+ switch v.(type) {
+ case int:
+ value, err = xmlconf.Int(k)
+ case int64:
+ value, err = xmlconf.Int64(k)
+ case float64:
+ value, err = xmlconf.Float(k)
+ case bool:
+ value, err = xmlconf.Bool(k)
+ case []string:
+ value = xmlconf.Strings(k)
+ case string:
+ value = xmlconf.String(k)
+ default:
+ value, err = xmlconf.DIY(k)
+ }
+ if err != nil {
+ t.Errorf("get key %q value fatal,%v err %s", k, v, err)
+ } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
+ t.Errorf("get key %q value, want %v got %v .", k, v, value)
+ }
+
}
+
if err = xmlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err)
}
if xmlconf.String("name") != "astaxie" {
t.Fatal("get name error")
}
- if xmlconf.Strings("emptystrings") != nil {
- t.Fatal("get emtpy strings error")
- }
}
diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go
index 7e1d0426..64e25cb3 100644
--- a/config/yaml/yaml.go
+++ b/config/yaml/yaml.go
@@ -19,14 +19,14 @@
// go install github.com/beego/goyaml2
//
// Usage:
-// import(
+// import(
// _ "github.com/astaxie/beego/config/yaml"
-// "github.com/astaxie/beego/config"
-// )
+// "github.com/astaxie/beego/config"
+// )
//
// cnf, err := config.NewConfig("yaml", "config.yaml")
//
-// more docs http://beego.me/docs/module/config.md
+//More docs http://beego.me/docs/module/config.md
package yaml
import (
@@ -110,6 +110,7 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
log.Println("Not a Map? >> ", string(buf), data)
cnf = nil
}
+ cnf = config.ExpandValueEnvForMap(cnf)
return
}
@@ -121,10 +122,11 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) {
- if v, ok := c.data[key]; ok {
- return config.ParseBool(v)
+ v, err := c.getData(key)
+ if err != nil {
+ return false, err
}
- return false, fmt.Errorf("not exist key: %q", key)
+ return config.ParseBool(v)
}
// DefaultBool return the bool value if has no error
@@ -139,8 +141,12 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
// Int returns the integer value for a given key.
func (c *ConfigContainer) Int(key string) (int, error) {
- if v, ok := c.data[key].(int64); ok {
- return int(v), nil
+ if v, err := c.getData(key); err != nil {
+ return 0, err
+ } else if vv, ok := v.(int); ok {
+ return vv, nil
+ } else if vv, ok := v.(int64); ok {
+ return int(vv), nil
}
return 0, errors.New("not int value")
}
@@ -157,8 +163,10 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
// Int64 returns the int64 value for a given key.
func (c *ConfigContainer) Int64(key string) (int64, error) {
- if v, ok := c.data[key].(int64); ok {
- return v, nil
+ if v, err := c.getData(key); err != nil {
+ return 0, err
+ } else if vv, ok := v.(int64); ok {
+ return vv, nil
}
return 0, errors.New("not bool value")
}
@@ -175,8 +183,14 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
// Float returns the float value for a given key.
func (c *ConfigContainer) Float(key string) (float64, error) {
- if v, ok := c.data[key].(float64); ok {
- return v, nil
+ if v, err := c.getData(key); err != nil {
+ return 0.0, err
+ } else if vv, ok := v.(float64); ok {
+ return vv, nil
+ } else if vv, ok := v.(int); ok {
+ return float64(vv), nil
+ } else if vv, ok := v.(int64); ok {
+ return float64(vv), nil
}
return 0.0, errors.New("not float64 value")
}
@@ -193,8 +207,10 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
// String returns the string value for a given key.
func (c *ConfigContainer) String(key string) string {
- if v, ok := c.data[key].(string); ok {
- return v
+ if v, err := c.getData(key); err == nil {
+ if vv, ok := v.(string); ok {
+ return vv
+ }
}
return ""
}
@@ -230,8 +246,8 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri
// GetSection returns map for the given section
func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
- v, ok := c.data[section]
- if ok {
+
+ if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
return nil, errors.New("not exist setction")
@@ -259,10 +275,19 @@ func (c *ConfigContainer) Set(key, val string) error {
// DIY returns the raw value by a given key.
func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
+ return c.getData(key)
+}
+
+func (c *ConfigContainer) getData(key string) (interface{}, error) {
+
+ if len(key) == 0 {
+ return nil, errors.New("key is emtpy")
+ }
+
if v, ok := c.data[key]; ok {
return v, nil
}
- return nil, errors.New("not exist key")
+ return nil, fmt.Errorf("not exist key %q", key)
}
func init() {
diff --git a/config/yaml/yaml_test.go b/config/yaml/yaml_test.go
index 80cbb8fe..49cc1d1e 100644
--- a/config/yaml/yaml_test.go
+++ b/config/yaml/yaml_test.go
@@ -15,13 +15,17 @@
package yaml
import (
+ "fmt"
"os"
"testing"
"github.com/astaxie/beego/config"
)
-var yamlcontext = `
+func TestYaml(t *testing.T) {
+
+ var (
+ yamlcontext = `
"appname": beeapi
"httpport": 8080
"mysqlport": 3600
@@ -29,9 +33,27 @@ var yamlcontext = `
"runmode": dev
"autorender": false
"copyrequestbody": true
+"PATH": GOPATH
+"path1": ${GOPATH}
+"path2": ${GOPATH||/home/go}
+"empty": ""
`
-func TestYaml(t *testing.T) {
+ keyValue = map[string]interface{}{
+ "appname": "beeapi",
+ "httpport": 8080,
+ "mysqlport": int64(3600),
+ "PI": 3.1415976,
+ "runmode": "dev",
+ "autorender": false,
+ "copyrequestbody": true,
+ "PATH": "GOPATH",
+ "path1": os.Getenv("GOPATH"),
+ "path2": os.Getenv("GOPATH"),
+ "error": "",
+ "emptystrings": []string{},
+ }
+ )
f, err := os.Create("testyaml.conf")
if err != nil {
t.Fatal(err)
@@ -47,32 +69,42 @@ func TestYaml(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+
if yamlconf.String("appname") != "beeapi" {
t.Fatal("appname not equal to beeapi")
}
- if port, err := yamlconf.Int("httpport"); err != nil || port != 8080 {
- t.Error(port)
- t.Fatal(err)
- }
- if port, err := yamlconf.Int64("mysqlport"); err != nil || port != 3600 {
- t.Error(port)
- t.Fatal(err)
- }
- if pi, err := yamlconf.Float("PI"); err != nil || pi != 3.1415976 {
- t.Error(pi)
- t.Fatal(err)
- }
- if yamlconf.String("runmode") != "dev" {
- t.Fatal("runmode not equal to dev")
- }
- if v, err := yamlconf.Bool("autorender"); err != nil || v != false {
- t.Error(v)
- t.Fatal(err)
- }
- if v, err := yamlconf.Bool("copyrequestbody"); err != nil || v != true {
- t.Error(v)
- t.Fatal(err)
+
+ for k, v := range keyValue {
+
+ var (
+ value interface{}
+ err error
+ )
+
+ switch v.(type) {
+ case int:
+ value, err = yamlconf.Int(k)
+ case int64:
+ value, err = yamlconf.Int64(k)
+ case float64:
+ value, err = yamlconf.Float(k)
+ case bool:
+ value, err = yamlconf.Bool(k)
+ case []string:
+ value = yamlconf.Strings(k)
+ case string:
+ value = yamlconf.String(k)
+ default:
+ value, err = yamlconf.DIY(k)
+ }
+ if err != nil {
+ t.Errorf("get key %q value fatal,%v err %s", k, v, err)
+ } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
+ t.Errorf("get key %q value, want %v got %v .", k, v, value)
+ }
+
}
+
if err = yamlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err)
}
@@ -80,7 +112,4 @@ func TestYaml(t *testing.T) {
t.Fatal("get name error")
}
- if yamlconf.Strings("emptystrings") != nil {
- t.Fatal("get emtpy strings error")
- }
}
diff --git a/config_test.go b/config_test.go
index cf4a781d..c9576afd 100644
--- a/config_test.go
+++ b/config_test.go
@@ -15,7 +15,11 @@
package beego
import (
+ "encoding/json"
+ "reflect"
"testing"
+
+ "github.com/astaxie/beego/config"
)
func TestDefaults(t *testing.T) {
@@ -27,3 +31,109 @@ func TestDefaults(t *testing.T) {
t.Errorf("FlashName was not set to default.")
}
}
+
+func TestAssignConfig_01(t *testing.T) {
+ _BConfig := &Config{}
+ _BConfig.AppName = "beego_test"
+ jcf := &config.JSONConfig{}
+ ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`))
+ assignSingleConfig(_BConfig, ac)
+ if _BConfig.AppName != "beego_json" {
+ t.Log(_BConfig)
+ t.FailNow()
+ }
+}
+
+func TestAssignConfig_02(t *testing.T) {
+ _BConfig := &Config{}
+ bs, _ := json.Marshal(newBConfig())
+
+ jsonMap := map[string]interface{}{}
+ json.Unmarshal(bs, &jsonMap)
+
+ configMap := map[string]interface{}{}
+ for k, v := range jsonMap {
+ if reflect.TypeOf(v).Kind() == reflect.Map {
+ for k1, v1 := range v.(map[string]interface{}) {
+ if reflect.TypeOf(v1).Kind() == reflect.Map {
+ for k2, v2 := range v1.(map[string]interface{}) {
+ configMap[k2] = v2
+ }
+ } else {
+ configMap[k1] = v1
+ }
+ }
+ } else {
+ configMap[k] = v
+ }
+ }
+ configMap["MaxMemory"] = 1024
+ configMap["Graceful"] = true
+ configMap["XSRFExpire"] = 32
+ configMap["SessionProviderConfig"] = "file"
+ configMap["FileLineNum"] = true
+
+ jcf := &config.JSONConfig{}
+ bs, _ = json.Marshal(configMap)
+ ac, _ := jcf.ParseData([]byte(bs))
+
+ for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} {
+ assignSingleConfig(i, ac)
+ }
+
+ if _BConfig.MaxMemory != 1024 {
+ t.Log(_BConfig.MaxMemory)
+ t.FailNow()
+ }
+
+ if !_BConfig.Listen.Graceful {
+ t.Log(_BConfig.Listen.Graceful)
+ t.FailNow()
+ }
+
+ if _BConfig.WebConfig.XSRFExpire != 32 {
+ t.Log(_BConfig.WebConfig.XSRFExpire)
+ t.FailNow()
+ }
+
+ if _BConfig.WebConfig.Session.SessionProviderConfig != "file" {
+ t.Log(_BConfig.WebConfig.Session.SessionProviderConfig)
+ t.FailNow()
+ }
+
+ if !_BConfig.Log.FileLineNum {
+ t.Log(_BConfig.Log.FileLineNum)
+ t.FailNow()
+ }
+
+}
+
+func TestAssignConfig_03(t *testing.T) {
+ jcf := &config.JSONConfig{}
+ ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`))
+ ac.Set("AppName", "test_app")
+ ac.Set("RunMode", "online")
+ ac.Set("StaticDir", "download:down download2:down2")
+ ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png")
+ assignConfig(ac)
+
+
+ t.Logf("%#v",BConfig)
+
+ if BConfig.AppName != "test_app" {
+ t.FailNow()
+ }
+
+ if BConfig.RunMode != "online" {
+ t.FailNow()
+ }
+ if BConfig.WebConfig.StaticDir["/download"] != "down" {
+ t.FailNow()
+ }
+ if BConfig.WebConfig.StaticDir["/download2"] != "down2" {
+ t.FailNow()
+ }
+ if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 {
+ t.FailNow()
+ }
+}
diff --git a/context/acceptencoder.go b/context/acceptencoder.go
index 033d9ca8..cb735445 100644
--- a/context/acceptencoder.go
+++ b/context/acceptencoder.go
@@ -27,6 +27,33 @@ import (
"sync"
)
+var (
+ //Default size==20B same as nginx
+ defaultGzipMinLength = 20
+ //Content will only be compressed if content length is either unknown or greater than gzipMinLength.
+ gzipMinLength = defaultGzipMinLength
+ //The compression level used for deflate compression. (0-9).
+ gzipCompressLevel int
+ //List of HTTP methods to compress. If not set, only GET requests are compressed.
+ includedMethods map[string]bool
+ getMethodOnly bool
+)
+
+func InitGzip(minLength, compressLevel int, methods []string) {
+ if minLength >= 0 {
+ gzipMinLength = minLength
+ }
+ gzipCompressLevel = compressLevel
+ if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression {
+ gzipCompressLevel = flate.BestSpeed
+ }
+ getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET")
+ includedMethods = make(map[string]bool, len(methods))
+ for _, v := range methods {
+ includedMethods[strings.ToUpper(v)] = true
+ }
+}
+
type resetWriter interface {
io.Writer
Reset(w io.Writer)
@@ -41,20 +68,20 @@ func (n nopResetWriter) Reset(w io.Writer) {
}
type acceptEncoder struct {
- name string
- levelEncode func(int) resetWriter
- bestSpeedPool *sync.Pool
- bestCompressionPool *sync.Pool
+ name string
+ levelEncode func(int) resetWriter
+ customCompressLevelPool *sync.Pool
+ bestCompressionPool *sync.Pool
}
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
- if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr}
}
var rwr resetWriter
switch level {
case flate.BestSpeed:
- rwr = ac.bestSpeedPool.Get().(resetWriter)
+ rwr = ac.customCompressLevelPool.Get().(resetWriter)
case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter)
default:
@@ -65,13 +92,18 @@ func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
}
func (ac acceptEncoder) put(wr resetWriter, level int) {
- if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return
}
wr.Reset(nil)
+
+ //notice
+ //compressionLevel==BestCompression DOES NOT MATTER
+ //sync.Pool will not memory leak
+
switch level {
- case flate.BestSpeed:
- ac.bestSpeedPool.Put(wr)
+ case gzipCompressLevel:
+ ac.customCompressLevelPool.Put(wr)
case flate.BestCompression:
ac.bestCompressionPool.Put(wr)
}
@@ -79,28 +111,22 @@ func (ac acceptEncoder) put(wr resetWriter, level int) {
var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
- gzipCompressEncoder = acceptEncoder{"gzip",
- func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
- &sync.Pool{
- New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr },
- },
- &sync.Pool{
- New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
- },
+ gzipCompressEncoder = acceptEncoder{
+ name: "gzip",
+ levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
+ customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }},
+ bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }},
}
//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
//deflate
//The "zlib" format defined in RFC 1950 [31] in combination with
//the "deflate" compression mechanism described in RFC 1951 [29].
- deflateCompressEncoder = acceptEncoder{"deflate",
- func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
- &sync.Pool{
- New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr },
- },
- &sync.Pool{
- New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
- },
+ deflateCompressEncoder = acceptEncoder{
+ name: "deflate",
+ levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
+ customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }},
+ bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }},
}
)
@@ -120,7 +146,11 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string,
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
- return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed)
+ if encoding == "" || len(content) < gzipMinLength {
+ _, err := writer.Write(content)
+ return false, "", err
+ }
+ return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel)
}
// writeLevel reads from reader,writes to writer by specific encoding and compress level
@@ -156,7 +186,10 @@ func ParseEncoding(r *http.Request) string {
if r == nil {
return ""
}
- return parseEncoding(r)
+ if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] {
+ return parseEncoding(r)
+ }
+ return ""
}
type q struct {
diff --git a/context/context.go b/context/context.go
index 63a1313d..03286097 100644
--- a/context/context.go
+++ b/context/context.go
@@ -24,13 +24,11 @@ package context
import (
"bufio"
- "bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
- "io"
"net"
"net/http"
"strconv"
@@ -67,18 +65,18 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.ResponseWriter.reset(rw)
ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx)
+ ctx._xsrfToken = ""
}
// Redirect does redirection to localurl with http header status code.
-// It sends http response header directly.
func (ctx *Context) Redirect(status int, localurl string) {
- ctx.Output.Header("Location", localurl)
- ctx.ResponseWriter.WriteHeader(status)
+ http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status)
}
// Abort stops this request.
// if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) {
+ ctx.Output.SetStatus(status)
panic(body)
}
@@ -195,14 +193,6 @@ func (r *Response) Write(p []byte) (int, error) {
return r.ResponseWriter.Write(p)
}
-// Copy writes the data to the connection as part of an HTTP reply,
-// and sets `started` to true.
-// started means the response has sent out.
-func (r *Response) Copy(buf *bytes.Buffer) (int64, error) {
- r.Started = true
- return io.Copy(r.ResponseWriter, buf)
-}
-
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (r *Response) WriteHeader(code int) {
diff --git a/context/context_test.go b/context/context_test.go
new file mode 100644
index 00000000..7c0535e0
--- /dev/null
+++ b/context/context_test.go
@@ -0,0 +1,47 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package context
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestXsrfReset_01(t *testing.T) {
+ r := &http.Request{}
+ c := NewContext()
+ c.Request = r
+ c.ResponseWriter = &Response{}
+ c.ResponseWriter.reset(httptest.NewRecorder())
+ c.Output.Reset(c)
+ c.Input.Reset(c)
+ c.XSRFToken("key", 16)
+ if c._xsrfToken == "" {
+ t.FailNow()
+ }
+ token := c._xsrfToken
+ c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r)
+ if c._xsrfToken != "" {
+ t.FailNow()
+ }
+ c.XSRFToken("key", 16)
+ if c._xsrfToken == "" {
+ t.FailNow()
+ }
+ if token == c._xsrfToken {
+ t.FailNow()
+ }
+}
diff --git a/context/input.go b/context/input.go
index edfdf530..c47996c9 100644
--- a/context/input.go
+++ b/context/input.go
@@ -89,6 +89,9 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
+ if scheme := input.Header("X-Forwarded-Proto"); scheme != "" {
+ return scheme
+ }
if input.Context.Request.URL.Scheme != "" {
return input.Context.Request.URL.Scheme
}
diff --git a/context/input_test.go b/context/input_test.go
index 24f6fd99..8887aec4 100644
--- a/context/input_test.go
+++ b/context/input_test.go
@@ -100,7 +100,7 @@ func TestSubDomain(t *testing.T) {
/* TODO Fix this
r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
diff --git a/context/output.go b/context/output.go
index 17404702..e1ad23e0 100644
--- a/context/output.go
+++ b/context/output.go
@@ -21,8 +21,11 @@ import (
"errors"
"fmt"
"html/template"
+ "io"
"mime"
"net/http"
+ "net/url"
+ "os"
"path/filepath"
"strconv"
"strings"
@@ -72,10 +75,11 @@ func (output *BeegoOutput) Body(content []byte) error {
if output.Status != 0 {
output.Context.ResponseWriter.WriteHeader(output.Status)
output.Status = 0
+ } else {
+ output.Context.ResponseWriter.Started = true
}
-
- _, err := output.Context.ResponseWriter.Copy(buf)
- return err
+ io.Copy(output.Context.ResponseWriter, buf)
+ return nil
}
// Cookie sets cookie value via given key.
@@ -235,13 +239,21 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
// Download forces response for download file.
// it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) {
+ // check get file error, file not found or other error.
+ if _, err := os.Stat(file); err != nil {
+ http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
+ return
+ }
+
+ var fName string
+ if len(filename) > 0 && filename[0] != "" {
+ fName = filename[0]
+ } else {
+ fName = filepath.Base(file)
+ }
+ output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName))
output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream")
- if len(filename) > 0 && filename[0] != "" {
- output.Header("Content-Disposition", "attachment; filename="+filename[0])
- } else {
- output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file))
- }
output.Header("Content-Transfer-Encoding", "binary")
output.Header("Expires", "0")
output.Header("Cache-Control", "must-revalidate")
@@ -269,55 +281,55 @@ func (output *BeegoOutput) SetStatus(status int) {
// IsCachable returns boolean of this request is cached.
// HTTP 304 means cached.
-func (output *BeegoOutput) IsCachable(status int) bool {
+func (output *BeegoOutput) IsCachable() bool {
return output.Status >= 200 && output.Status < 300 || output.Status == 304
}
// IsEmpty returns boolean of this request is empty.
// HTTP 201,204 and 304 means empty.
-func (output *BeegoOutput) IsEmpty(status int) bool {
+func (output *BeegoOutput) IsEmpty() bool {
return output.Status == 201 || output.Status == 204 || output.Status == 304
}
// IsOk returns boolean of this request runs well.
// HTTP 200 means ok.
-func (output *BeegoOutput) IsOk(status int) bool {
+func (output *BeegoOutput) IsOk() bool {
return output.Status == 200
}
// IsSuccessful returns boolean of this request runs successfully.
// HTTP 2xx means ok.
-func (output *BeegoOutput) IsSuccessful(status int) bool {
+func (output *BeegoOutput) IsSuccessful() bool {
return output.Status >= 200 && output.Status < 300
}
// IsRedirect returns boolean of this request is redirection header.
// HTTP 301,302,307 means redirection.
-func (output *BeegoOutput) IsRedirect(status int) bool {
+func (output *BeegoOutput) IsRedirect() bool {
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
}
// IsForbidden returns boolean of this request is forbidden.
// HTTP 403 means forbidden.
-func (output *BeegoOutput) IsForbidden(status int) bool {
+func (output *BeegoOutput) IsForbidden() bool {
return output.Status == 403
}
// IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden.
-func (output *BeegoOutput) IsNotFound(status int) bool {
+func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404
}
// IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
-func (output *BeegoOutput) IsClientError(status int) bool {
+func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500
}
// IsServerError returns boolean of this server handler errors.
// HTTP 5xx means server internal error.
-func (output *BeegoOutput) IsServerError(status int) bool {
+func (output *BeegoOutput) IsServerError() bool {
return output.Status >= 500 && output.Status < 600
}
diff --git a/controller.go b/controller.go
index 85894275..3a9d1618 100644
--- a/controller.go
+++ b/controller.go
@@ -208,7 +208,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue
}
buf.Reset()
- err = executeTemplate(&buf, sectionTpl, c.Data)
+ err = ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil {
return nil, err
}
@@ -217,7 +217,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
}
buf.Reset()
- executeTemplate(&buf, c.Layout, c.Data)
+ ExecuteTemplate(&buf, c.Layout, c.Data)
}
return buf.Bytes(), err
}
@@ -242,7 +242,7 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
}
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
}
- return buf, executeTemplate(&buf, c.TplName, c.Data)
+ return buf, ExecuteTemplate(&buf, c.TplName, c.Data)
}
// Redirect sends the redirection response to url with status code.
@@ -261,12 +261,13 @@ func (c *Controller) Abort(code string) {
// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body.
func (c *Controller) CustomAbort(status int, body string) {
- c.Ctx.Output.Status = status
- // first panic from ErrorMaps, is is user defined error functions.
+ // first panic from ErrorMaps, it is user defined error functions.
if _, ok := ErrorMaps[body]; ok {
+ c.Ctx.Output.Status = status
panic(body)
}
// last panic user string
+ c.Ctx.ResponseWriter.WriteHeader(status)
c.Ctx.ResponseWriter.Write([]byte(body))
panic(ErrAbort)
}
diff --git a/error.go b/error.go
index 4f48fab2..bad08d86 100644
--- a/error.go
+++ b/error.go
@@ -210,159 +210,139 @@ var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(401),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested can't be authorized." +
- " Perhaps you are here because:" +
- "
" +
- " The credentials you supplied are incorrect" +
- " There are errors in the website address" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 401,
+ " The page you have requested can't be authorized."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The credentials you supplied are incorrect"+
+ " There are errors in the website address"+
+ "
",
+ )
}
// show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(402),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested Payment Required." +
- " Perhaps you are here because:" +
- "
" +
- " The credentials you supplied are incorrect" +
- " There are errors in the website address" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 402,
+ " The page you have requested Payment Required."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The credentials you supplied are incorrect"+
+ " There are errors in the website address"+
+ "
",
+ )
}
// show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(403),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested is forbidden." +
- " Perhaps you are here because:" +
- "
" +
- " Your address may be blocked" +
- " The site may be disabled" +
- " You need to log in" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 403,
+ " The page you have requested is forbidden."+
+ " Perhaps you are here because:"+
+ "
"+
+ " Your address may be blocked"+
+ " The site may be disabled"+
+ " You need to log in"+
+ "
",
+ )
}
-// show 404 notfound error.
+// show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(404),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested has flown the coop." +
- " Perhaps you are here because:" +
- "
" +
- " The page has moved" +
- " The page no longer exists" +
- " You were looking for your puppy and got lost" +
- " You like 404 pages" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 404,
+ " The page you have requested has flown the coop."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The page has moved"+
+ " The page no longer exists"+
+ " You were looking for your puppy and got lost"+
+ " You like 404 pages"+
+ "
",
+ )
}
// show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(405),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The method you have requested Not Allowed." +
- " Perhaps you are here because:" +
- "
" +
- " The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" +
- " The response MUST include an Allow header containing a list of valid methods for the requested resource." +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 405,
+ " The method you have requested Not Allowed."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+
+ " The response MUST include an Allow header containing a list of valid methods for the requested resource."+
+ "
",
+ )
}
// show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(500),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested is down right now." +
- "
" +
- " Please try again later and report the error to the website administrator" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 500,
+ " The page you have requested is down right now."+
+ "
"+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
}
// show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(504),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested is Not Implemented." +
- "
" +
- " Please try again later and report the error to the website administrator" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 501,
+ " The page you have requested is Not Implemented."+
+ "
"+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
}
// show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(502),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested is down right now." +
- "
" +
- " The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." +
- " Please try again later and report the error to the website administrator" +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 502,
+ " The page you have requested is down right now."+
+ "
"+
+ " The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
}
// show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
- t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := map[string]interface{}{
- "Title": http.StatusText(503),
- "BeegoVersion": VERSION,
- }
- data["Content"] = template.HTML(" The page you have requested is unavailable." +
- " Perhaps you are here because:" +
- "
" +
- "
The page is overloaded" +
- " Please try again later." +
- "
")
- t.Execute(rw, data)
+ responseError(rw, r,
+ 503,
+ " The page you have requested is unavailable."+
+ " Perhaps you are here because:"+
+ "
"+
+ "
The page is overloaded"+
+ " Please try again later."+
+ "
",
+ )
}
// show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 504,
+ " The page you have requested is unavailable"+
+ " Perhaps you are here because:"+
+ "
"+
+ "
The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+
+ " Please try again later."+
+ "
",
+ )
+}
+
+func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := map[string]interface{}{
- "Title": http.StatusText(504),
+ "Title": http.StatusText(errCode),
"BeegoVersion": VERSION,
+ "Content": template.HTML(errContent),
}
- data["Content"] = template.HTML(" The page you have requested is unavailable." +
- " Perhaps you are here because:" +
- "
" +
- "
The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
- " Please try again later." +
- "
")
t.Execute(rw, data)
}
@@ -408,7 +388,10 @@ func exception(errCode string, ctx *context.Context) {
if err == nil {
return v
}
- return 503
+ if ctx.Output.Status == 0 {
+ return 503
+ }
+ return ctx.Output.Status
}
for _, ec := range []string{errCode, "503", "500"} {
diff --git a/error_test.go b/error_test.go
new file mode 100644
index 00000000..2fb8f962
--- /dev/null
+++ b/error_test.go
@@ -0,0 +1,88 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+type errorTestController struct {
+ Controller
+}
+
+const parseCodeError = "parse code error"
+
+func (ec *errorTestController) Get() {
+ errorCode, err := ec.GetInt("code")
+ if err != nil {
+ ec.Abort(parseCodeError)
+ }
+ if errorCode != 0 {
+ ec.CustomAbort(errorCode, ec.GetString("code"))
+ }
+ ec.Abort("404")
+}
+
+func TestErrorCode_01(t *testing.T) {
+ registerDefaultErrorHandler()
+ for k := range ErrorMaps {
+ r, _ := http.NewRequest("GET", "/error?code="+k, nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ code, _ := strconv.Atoi(k)
+ if w.Code != code {
+ t.Fail()
+ }
+ if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) {
+ t.Fail()
+ }
+ }
+}
+
+func TestErrorCode_02(t *testing.T) {
+ registerDefaultErrorHandler()
+ r, _ := http.NewRequest("GET", "/error?code=0", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ if w.Code != 404 {
+ t.Fail()
+ }
+}
+
+func TestErrorCode_03(t *testing.T) {
+ registerDefaultErrorHandler()
+ r, _ := http.NewRequest("GET", "/error?code=panic", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ if w.Code != 200 {
+ t.Fail()
+ }
+ if string(w.Body.Bytes()) != parseCodeError {
+ t.Fail()
+ }
+}
diff --git a/filter_test.go b/filter_test.go
index d9928d8d..4ca4d2b8 100644
--- a/filter_test.go
+++ b/filter_test.go
@@ -20,14 +20,8 @@ import (
"testing"
"github.com/astaxie/beego/context"
- "github.com/astaxie/beego/logs"
)
-func init() {
- BeeLogger = logs.NewLogger(10000)
- BeeLogger.SetLogger("console", "")
-}
-
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
}
diff --git a/hooks.go b/hooks.go
index 59b10b32..3dca1b8d 100644
--- a/hooks.go
+++ b/hooks.go
@@ -6,6 +6,8 @@ import (
"net/http"
"path/filepath"
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
)
@@ -45,13 +47,16 @@ func registerSession() error {
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
conf := map[string]interface{}{
- "cookieName": BConfig.WebConfig.Session.SessionName,
- "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
- "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
- "secure": BConfig.Listen.EnableHTTPS,
- "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
- "domain": BConfig.WebConfig.Session.SessionDomain,
- "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
+ "cookieName": BConfig.WebConfig.Session.SessionName,
+ "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
+ "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
+ "secure": BConfig.Listen.EnableHTTPS,
+ "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
+ "domain": BConfig.WebConfig.Session.SessionDomain,
+ "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
+ "enableSidInHttpHeader": BConfig.WebConfig.Session.EnableSidInHttpHeader,
+ "sessionNameInHttpHeader": BConfig.WebConfig.Session.SessionNameInHttpHeader,
+ "enableSidInUrlQuery": BConfig.WebConfig.Session.EnableSidInUrlQuery,
}
confBytes, err := json.Marshal(conf)
if err != nil {
@@ -70,24 +75,27 @@ func registerSession() error {
func registerTemplate() error {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV {
- Warn(err)
+ logs.Warn(err)
}
return err
}
return nil
}
-func registerDocs() error {
- if BConfig.WebConfig.EnableDocs {
- Get("/docs", serverDocs)
- Get("/docs/*", serverDocs)
- }
- return nil
-}
-
func registerAdmin() error {
if BConfig.Listen.EnableAdmin {
go beeAdminApp.Run()
}
return nil
}
+
+func registerGzip() error {
+ if BConfig.EnableGzip {
+ context.InitGzip(
+ AppConfig.DefaultInt("gzipMinLength", -1),
+ AppConfig.DefaultInt("gzipCompressLevel", -1),
+ AppConfig.DefaultStrings("includedMethods", []string{"GET"}),
+ )
+ }
+ return nil
+}
diff --git a/log.go b/log.go
index 46ec57dd..e9412f92 100644
--- a/log.go
+++ b/log.go
@@ -33,82 +33,77 @@ const (
)
// BeeLogger references the used application logger.
-var BeeLogger = logs.NewLogger(100)
+var BeeLogger = logs.GetBeeLogger()
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
- BeeLogger.SetLevel(l)
+ logs.SetLevel(l)
}
// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
- BeeLogger.EnableFuncCallDepth(b)
- BeeLogger.SetLogFuncCallDepth(3)
+ logs.SetLogFuncCall(b)
}
// SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error {
- err := BeeLogger.SetLogger(adaptername, config)
- if err != nil {
- return err
- }
- return nil
+ return logs.SetLogger(adaptername, config)
}
// Emergency logs a message at emergency level.
func Emergency(v ...interface{}) {
- BeeLogger.Emergency(generateFmtStr(len(v)), v...)
+ logs.Emergency(generateFmtStr(len(v)), v...)
}
// Alert logs a message at alert level.
func Alert(v ...interface{}) {
- BeeLogger.Alert(generateFmtStr(len(v)), v...)
+ logs.Alert(generateFmtStr(len(v)), v...)
}
// Critical logs a message at critical level.
func Critical(v ...interface{}) {
- BeeLogger.Critical(generateFmtStr(len(v)), v...)
+ logs.Critical(generateFmtStr(len(v)), v...)
}
// Error logs a message at error level.
func Error(v ...interface{}) {
- BeeLogger.Error(generateFmtStr(len(v)), v...)
+ logs.Error(generateFmtStr(len(v)), v...)
}
// Warning logs a message at warning level.
func Warning(v ...interface{}) {
- BeeLogger.Warning(generateFmtStr(len(v)), v...)
+ logs.Warning(generateFmtStr(len(v)), v...)
}
// Warn compatibility alias for Warning()
func Warn(v ...interface{}) {
- BeeLogger.Warn(generateFmtStr(len(v)), v...)
+ logs.Warn(generateFmtStr(len(v)), v...)
}
// Notice logs a message at notice level.
func Notice(v ...interface{}) {
- BeeLogger.Notice(generateFmtStr(len(v)), v...)
+ logs.Notice(generateFmtStr(len(v)), v...)
}
// Informational logs a message at info level.
func Informational(v ...interface{}) {
- BeeLogger.Informational(generateFmtStr(len(v)), v...)
+ logs.Informational(generateFmtStr(len(v)), v...)
}
// Info compatibility alias for Warning()
func Info(v ...interface{}) {
- BeeLogger.Info(generateFmtStr(len(v)), v...)
+ logs.Info(generateFmtStr(len(v)), v...)
}
// Debug logs a message at debug level.
func Debug(v ...interface{}) {
- BeeLogger.Debug(generateFmtStr(len(v)), v...)
+ logs.Debug(generateFmtStr(len(v)), v...)
}
// Trace logs a message at trace level.
// compatibility alias for Warning()
func Trace(v ...interface{}) {
- BeeLogger.Trace(generateFmtStr(len(v)), v...)
+ logs.Trace(generateFmtStr(len(v)), v...)
}
func generateFmtStr(n int) string {
diff --git a/logs/conn.go b/logs/conn.go
index 1db1a427..6d5bf6bf 100644
--- a/logs/conn.go
+++ b/logs/conn.go
@@ -113,5 +113,5 @@ func (c *connWriter) needToConnectOnMsg() bool {
}
func init() {
- Register("conn", NewConn)
+ Register(AdapterConn, NewConn)
}
diff --git a/logs/console.go b/logs/console.go
index dc41dd7d..e6bf6c29 100644
--- a/logs/console.go
+++ b/logs/console.go
@@ -97,5 +97,5 @@ func (c *consoleWriter) Flush() {
}
func init() {
- Register("console", NewConsole)
+ Register(AdapterConsole, NewConsole)
}
diff --git a/logs/es/es.go b/logs/es/es.go
index 397ca2ef..22f4f650 100644
--- a/logs/es/es.go
+++ b/logs/es/es.go
@@ -76,5 +76,5 @@ func (el *esLogger) Flush() {
}
func init() {
- logs.Register("es", NewES)
+ logs.Register(logs.AdapterEs, NewES)
}
diff --git a/logs/file.go b/logs/file.go
index 9d3f78a0..7798a221 100644
--- a/logs/file.go
+++ b/logs/file.go
@@ -22,6 +22,7 @@ import (
"io"
"os"
"path/filepath"
+ "strconv"
"strings"
"sync"
"time"
@@ -30,7 +31,7 @@ import (
// fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type fileLogWriter struct {
- sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
+ sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file
Filename string `json:"filename"`
fileWriter *os.File
@@ -47,12 +48,13 @@ type fileLogWriter struct {
Daily bool `json:"daily"`
MaxDays int64 `json:"maxdays"`
dailyOpenDate int
+ dailyOpenTime time.Time
Rotate bool `json:"rotate"`
Level int `json:"level"`
- Perm os.FileMode `json:"perm"`
+ Perm string `json:"perm"`
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
}
@@ -60,14 +62,11 @@ type fileLogWriter struct {
// newFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger {
w := &fileLogWriter{
- Filename: "",
- MaxLines: 1000000,
- MaxSize: 1 << 28, //256 MB
- Daily: true,
- MaxDays: 7,
- Rotate: true,
- Level: LevelTrace,
- Perm: 0660,
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
}
return w
}
@@ -77,11 +76,11 @@ func newFileWriter() Logger {
// {
// "filename":"logs/beego.log",
// "maxLines":10000,
-// "maxsize":1<<30,
+// "maxsize":1024,
// "daily":true,
// "maxDays":15,
// "rotate":true,
-// "perm":0600
+// "perm":"0600"
// }
func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonConfig), w)
@@ -128,7 +127,9 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
h, d := formatTimeHeader(when)
msg = string(h) + msg + "\n"
if w.Rotate {
+ w.RLock()
if w.needRotate(len(msg), d) {
+ w.RUnlock()
w.Lock()
if w.needRotate(len(msg), d) {
if err := w.doRotate(when); err != nil {
@@ -136,6 +137,8 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
}
}
w.Unlock()
+ } else {
+ w.RUnlock()
}
}
@@ -151,7 +154,11 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
- fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm)
+ perm, err := strconv.ParseInt(w.Perm, 8, 64)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
return fd, err
}
@@ -162,8 +169,12 @@ func (w *fileLogWriter) initFd() error {
return fmt.Errorf("get stat err: %s\n", err)
}
w.maxSizeCurSize = int(fInfo.Size())
- w.dailyOpenDate = time.Now().Day()
+ w.dailyOpenTime = time.Now()
+ w.dailyOpenDate = w.dailyOpenTime.Day()
w.maxLinesCurLines = 0
+ if w.Daily {
+ go w.dailyRotate(w.dailyOpenTime)
+ }
if fInfo.Size() > 0 {
count, err := w.lines()
if err != nil {
@@ -174,6 +185,22 @@ func (w *fileLogWriter) initFd() error {
return nil
}
+func (w *fileLogWriter) dailyRotate(openTime time.Time) {
+ y, m, d := openTime.Add(24 * time.Hour).Date()
+ nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
+ tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
+ select {
+ case <-tm.C:
+ w.Lock()
+ if w.needRotate(0, time.Now().Day()) {
+ if err := w.doRotate(time.Now()); err != nil {
+ fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+ }
+ }
+ w.Unlock()
+ }
+}
+
func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename)
if err != nil {
@@ -204,22 +231,29 @@ func (w *fileLogWriter) lines() (int, error) {
// DoRotate means it need to write file in new file.
// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size)
func (w *fileLogWriter) doRotate(logTime time.Time) error {
- _, err := os.Lstat(w.Filename)
- if err != nil {
- return err
- }
// file exists
// Find the next available number
num := 1
fName := ""
+
+ _, err := os.Lstat(w.Filename)
+ if err != nil {
+ //even if the file is not exist or other ,we should RESTART the logger
+ goto RESTART_LOGGER
+ }
+
if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName)
}
} else {
- fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, logTime.Format("2006-01-02"), w.suffix)
+ fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix)
_, err = os.Lstat(fName)
+ for ; err == nil && num <= 999; num++ {
+ fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix)
+ _, err = os.Lstat(fName)
+ }
}
// return error if the last file checked still existed
if err == nil {
@@ -231,16 +265,18 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger
- renameErr := os.Rename(w.Filename, fName)
+ err = os.Rename(w.Filename, fName)
// re-start logger
+RESTART_LOGGER:
+
startLoggerErr := w.startLogger()
go w.deleteOldLog()
if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
}
- if renameErr != nil {
- return fmt.Errorf("Rotate: %s\n", renameErr)
+ if err != nil {
+ return fmt.Errorf("Rotate: %s\n", err)
}
return nil
@@ -255,8 +291,12 @@ func (w *fileLogWriter) deleteOldLog() {
}
}()
- if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) {
- if strings.HasPrefix(filepath.Base(path), w.fileNameOnly) &&
+ if info == nil {
+ return
+ }
+
+ if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) {
+ if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
@@ -278,5 +318,5 @@ func (w *fileLogWriter) Flush() {
}
func init() {
- Register("file", newFileWriter)
+ Register(AdapterFile, newFileWriter)
}
diff --git a/logs/file_test.go b/logs/file_test.go
index 1fa6cdaa..23370947 100644
--- a/logs/file_test.go
+++ b/logs/file_test.go
@@ -17,12 +17,34 @@ package logs
import (
"bufio"
"fmt"
+ "io/ioutil"
"os"
"strconv"
"testing"
"time"
)
+func TestFilePerm(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test.log", "perm": "0600"}`)
+ log.Debug("debug")
+ log.Informational("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ file, err := os.Stat("test.log")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if file.Mode() != 0600 {
+ t.Fatal("unexpected log file permission")
+ }
+ os.Remove("test.log")
+}
+
func TestFile1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
@@ -89,7 +111,7 @@ func TestFile2(t *testing.T) {
os.Remove("test2.log")
}
-func TestFileRotate(t *testing.T) {
+func TestFileRotate_01(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug")
@@ -110,6 +132,90 @@ func TestFileRotate(t *testing.T) {
os.Remove("test3.log")
}
+func TestFileRotate_02(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ testFileRotate(t, fn1, fn2)
+}
+
+func TestFileRotate_03(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileRotate(t, fn1, fn2)
+ os.Remove(fn)
+}
+
+func TestFileRotate_04(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ testFileDailyRotate(t, fn1, fn2)
+}
+
+func TestFileRotate_05(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileDailyRotate(t, fn1, fn2)
+ os.Remove(fn)
+}
+
+func testFileRotate(t *testing.T, fn1, fn2 string) {
+ fw := &fileLogWriter{
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ }
+ fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
+ fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
+ fw.dailyOpenDate = fw.dailyOpenTime.Day()
+ fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug)
+
+ for _, file := range []string{fn1, fn2} {
+ _, err := os.Stat(file)
+ if err != nil {
+ t.FailNow()
+ }
+ os.Remove(file)
+ }
+ fw.Destroy()
+}
+
+func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
+ fw := &fileLogWriter{
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ }
+ fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
+ fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
+ fw.dailyOpenDate = fw.dailyOpenTime.Day()
+ today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location())
+ today = today.Add(-1 * time.Second)
+ fw.dailyRotate(today)
+ for _, file := range []string{fn1, fn2} {
+ _, err := os.Stat(file)
+ if err != nil {
+ t.FailNow()
+ }
+ content, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.FailNow()
+ }
+ if len(content) > 0 {
+ t.FailNow()
+ }
+ os.Remove(file)
+ }
+ fw.Destroy()
+}
+
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
diff --git a/logs/log.go b/logs/log.go
index a5982e45..c43782f3 100644
--- a/logs/log.go
+++ b/logs/log.go
@@ -35,10 +35,12 @@ package logs
import (
"fmt"
+ "log"
"os"
"path"
"runtime"
"strconv"
+ "strings"
"sync"
"time"
)
@@ -55,16 +57,28 @@ const (
LevelDebug
)
-// Legacy loglevel constants to ensure backwards compatibility.
-//
-// Deprecated: will be removed in 1.5.0.
+// levelLogLogger is defined to implement log.Logger
+// the real log level will be LevelEmergency
+const levelLoggerImpl = -1
+
+// Name for adapter with beego official support
+const (
+ AdapterConsole = "console"
+ AdapterFile = "file"
+ AdapterMultiFile = "multifile"
+ AdapterMail = "stmp"
+ AdapterConn = "conn"
+ AdapterEs = "es"
+)
+
+// Legacy log level constants to ensure backwards compatibility.
const (
LevelInfo = LevelInformational
LevelTrace = LevelDebug
LevelWarn = LevelWarning
)
-type loggerType func() Logger
+type newLoggerFunc func() Logger
// Logger defines the behavior of a log provider.
type Logger interface {
@@ -74,12 +88,13 @@ type Logger interface {
Flush()
}
-var adapters = make(map[string]loggerType)
+var adapters = make(map[string]newLoggerFunc)
+var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "}
// Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
-func Register(name string, log loggerType) {
+func Register(name string, log newLoggerFunc) {
if log == nil {
panic("logs: Register provide is nil")
}
@@ -94,15 +109,19 @@ func Register(name string, log loggerType) {
type BeeLogger struct {
lock sync.Mutex
level int
+ init bool
enableFuncCallDepth bool
loggerFuncCallDepth int
asynchronous bool
+ msgChanLen int64
msgChan chan *logMsg
signalChan chan string
wg sync.WaitGroup
outputs []*nameLogger
}
+const defaultAsyncMsgLen = 1e3
+
type nameLogger struct {
Logger
name string
@@ -119,18 +138,31 @@ var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger.
// channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way.
-func NewLogger(channelLen int64) *BeeLogger {
+func NewLogger(channelLens ...int64) *BeeLogger {
bl := new(BeeLogger)
bl.level = LevelDebug
bl.loggerFuncCallDepth = 2
- bl.msgChan = make(chan *logMsg, channelLen)
+ bl.msgChanLen = append(channelLens, 0)[0]
+ if bl.msgChanLen <= 0 {
+ bl.msgChanLen = defaultAsyncMsgLen
+ }
bl.signalChan = make(chan string, 1)
+ bl.setLogger(AdapterConsole)
return bl
}
// Async set the log to asynchronous and start the goroutine
-func (bl *BeeLogger) Async() *BeeLogger {
+func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger {
+ bl.lock.Lock()
+ defer bl.lock.Unlock()
+ if bl.asynchronous {
+ return bl
+ }
bl.asynchronous = true
+ if len(msgLen) > 0 && msgLen[0] > 0 {
+ bl.msgChanLen = msgLen[0]
+ }
+ bl.msgChan = make(chan *logMsg, bl.msgChanLen)
logMsgPool = &sync.Pool{
New: func() interface{} {
return &logMsg{}
@@ -143,10 +175,8 @@ func (bl *BeeLogger) Async() *BeeLogger {
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
-func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
- bl.lock.Lock()
- defer bl.lock.Unlock()
-
+func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
+ config := append(configs, "{}")[0]
for _, l := range bl.outputs {
if l.name == adapterName {
return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
@@ -168,6 +198,18 @@ func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
return nil
}
+// SetLogger provides a given logger adapter into BeeLogger with config string.
+// config need to be correct JSON as string: {"interval":360}.
+func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error {
+ bl.lock.Lock()
+ defer bl.lock.Unlock()
+ if !bl.init {
+ bl.outputs = []*nameLogger{}
+ bl.init = true
+ }
+ return bl.setLogger(adapterName, configs...)
+}
+
// DelLogger remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock()
@@ -196,7 +238,37 @@ func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) {
}
}
-func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
+func (bl *BeeLogger) Write(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return 0, nil
+ }
+ // writeMsg will always add a '\n' character
+ if p[len(p)-1] == '\n' {
+ p = p[0 : len(p)-1]
+ }
+ // set levelLoggerImpl to ensure all log message will be write out
+ err = bl.writeMsg(levelLoggerImpl, string(p))
+ if err == nil {
+ return len(p), err
+ }
+ return 0, err
+}
+
+func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error {
+ if !bl.init {
+ bl.lock.Lock()
+ bl.setLogger(AdapterConsole)
+ bl.lock.Unlock()
+ }
+ if logLevel == levelLoggerImpl {
+ // set to emergency to ensure all log will be print out correctly
+ logLevel = LevelEmergency
+ } else {
+ msg = levelPrefix[logLevel] + msg
+ }
+ if len(v) > 0 {
+ msg = fmt.Sprintf(msg, v...)
+ }
when := time.Now()
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
@@ -205,7 +277,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
line = 0
}
_, filename := path.Split(file)
- msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg
+ msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
}
if bl.asynchronous {
lm := logMsgPool.Get().(*logMsg)
@@ -273,8 +345,7 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level {
return
}
- msg := fmt.Sprintf("[M] "+format, v...)
- bl.writeMsg(LevelEmergency, msg)
+ bl.writeMsg(LevelEmergency, format, v...)
}
// Alert Log ALERT level message.
@@ -282,8 +353,7 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level {
return
}
- msg := fmt.Sprintf("[A] "+format, v...)
- bl.writeMsg(LevelAlert, msg)
+ bl.writeMsg(LevelAlert, format, v...)
}
// Critical Log CRITICAL level message.
@@ -291,8 +361,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level {
return
}
- msg := fmt.Sprintf("[C] "+format, v...)
- bl.writeMsg(LevelCritical, msg)
+ bl.writeMsg(LevelCritical, format, v...)
}
// Error Log ERROR level message.
@@ -300,17 +369,12 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level {
return
}
- msg := fmt.Sprintf("[E] "+format, v...)
- bl.writeMsg(LevelError, msg)
+ bl.writeMsg(LevelError, format, v...)
}
// Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) {
- if LevelWarning > bl.level {
- return
- }
- msg := fmt.Sprintf("[W] "+format, v...)
- bl.writeMsg(LevelWarning, msg)
+ bl.Warn(format, v...)
}
// Notice Log NOTICE level message.
@@ -318,17 +382,12 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level {
return
}
- msg := fmt.Sprintf("[N] "+format, v...)
- bl.writeMsg(LevelNotice, msg)
+ bl.writeMsg(LevelNotice, format, v...)
}
// Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) {
- if LevelInformational > bl.level {
- return
- }
- msg := fmt.Sprintf("[I] "+format, v...)
- bl.writeMsg(LevelInformational, msg)
+ bl.Info(format, v...)
}
// Debug Log DEBUG level message.
@@ -336,38 +395,31 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
- msg := fmt.Sprintf("[D] "+format, v...)
- bl.writeMsg(LevelDebug, msg)
+ bl.writeMsg(LevelDebug, format, v...)
}
// Warn Log WARN level message.
// compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
- if LevelWarning > bl.level {
+ if LevelWarn > bl.level {
return
}
- msg := fmt.Sprintf("[W] "+format, v...)
- bl.writeMsg(LevelWarning, msg)
+ bl.writeMsg(LevelWarn, format, v...)
}
// Info Log INFO level message.
// compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) {
- if LevelInformational > bl.level {
+ if LevelInfo > bl.level {
return
}
- msg := fmt.Sprintf("[I] "+format, v...)
- bl.writeMsg(LevelInformational, msg)
+ bl.writeMsg(LevelInfo, format, v...)
}
// Trace Log TRACE level message.
// compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
- if LevelDebug > bl.level {
- return
- }
- msg := fmt.Sprintf("[D] "+format, v...)
- bl.writeMsg(LevelDebug, msg)
+ bl.Debug(format, v...)
}
// Flush flush all chan data.
@@ -386,6 +438,7 @@ func (bl *BeeLogger) Close() {
if bl.asynchronous {
bl.signalChan <- "close"
bl.wg.Wait()
+ close(bl.msgChan)
} else {
bl.flush()
for _, l := range bl.outputs {
@@ -393,7 +446,6 @@ func (bl *BeeLogger) Close() {
}
bl.outputs = nil
}
- close(bl.msgChan)
close(bl.signalChan)
}
@@ -407,16 +459,175 @@ func (bl *BeeLogger) Reset() {
}
func (bl *BeeLogger) flush() {
- for {
- if len(bl.msgChan) > 0 {
- bm := <-bl.msgChan
- bl.writeToLoggers(bm.when, bm.msg, bm.level)
- logMsgPool.Put(bm)
- continue
+ if bl.asynchronous {
+ for {
+ if len(bl.msgChan) > 0 {
+ bm := <-bl.msgChan
+ bl.writeToLoggers(bm.when, bm.msg, bm.level)
+ logMsgPool.Put(bm)
+ continue
+ }
+ break
}
- break
}
for _, l := range bl.outputs {
l.Flush()
}
}
+
+// beeLogger references the used application logger.
+var beeLogger *BeeLogger = NewLogger()
+
+// GetLogger returns the default BeeLogger
+func GetBeeLogger() *BeeLogger {
+ return beeLogger
+}
+
+var beeLoggerMap = struct {
+ sync.RWMutex
+ logs map[string]*log.Logger
+}{
+ logs: map[string]*log.Logger{},
+}
+
+// GetLogger returns the default BeeLogger
+func GetLogger(prefixes ...string) *log.Logger {
+ prefix := append(prefixes, "")[0]
+ if prefix != "" {
+ prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix))
+ }
+ beeLoggerMap.RLock()
+ l, ok := beeLoggerMap.logs[prefix]
+ if ok {
+ beeLoggerMap.RUnlock()
+ return l
+ }
+ beeLoggerMap.RUnlock()
+ beeLoggerMap.Lock()
+ defer beeLoggerMap.Unlock()
+ l, ok = beeLoggerMap.logs[prefix]
+ if !ok {
+ l = log.New(beeLogger, prefix, 0)
+ beeLoggerMap.logs[prefix] = l
+ }
+ return l
+}
+
+// Reset will remove all the adapter
+func Reset() {
+ beeLogger.Reset()
+}
+
+func Async(msgLen ...int64) *BeeLogger {
+ return beeLogger.Async(msgLen...)
+}
+
+// SetLevel sets the global log level used by the simple logger.
+func SetLevel(l int) {
+ beeLogger.SetLevel(l)
+}
+
+// EnableFuncCallDepth enable log funcCallDepth
+func EnableFuncCallDepth(b bool) {
+ beeLogger.enableFuncCallDepth = b
+}
+
+// SetLogFuncCall set the CallDepth, default is 3
+func SetLogFuncCall(b bool) {
+ beeLogger.EnableFuncCallDepth(b)
+ beeLogger.SetLogFuncCallDepth(3)
+}
+
+// SetLogFuncCallDepth set log funcCallDepth
+func SetLogFuncCallDepth(d int) {
+ beeLogger.loggerFuncCallDepth = d
+}
+
+// SetLogger sets a new logger.
+func SetLogger(adapter string, config ...string) error {
+ err := beeLogger.SetLogger(adapter, config...)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// Emergency logs a message at emergency level.
+func Emergency(f interface{}, v ...interface{}) {
+ beeLogger.Emergency(formatLog(f, v...))
+}
+
+// Alert logs a message at alert level.
+func Alert(f interface{}, v ...interface{}) {
+ beeLogger.Alert(formatLog(f, v...))
+}
+
+// Critical logs a message at critical level.
+func Critical(f interface{}, v ...interface{}) {
+ beeLogger.Critical(formatLog(f, v...))
+}
+
+// Error logs a message at error level.
+func Error(f interface{}, v ...interface{}) {
+ beeLogger.Error(formatLog(f, v...))
+}
+
+// Warning logs a message at warning level.
+func Warning(f interface{}, v ...interface{}) {
+ beeLogger.Warn(formatLog(f, v...))
+}
+
+// Warn compatibility alias for Warning()
+func Warn(f interface{}, v ...interface{}) {
+ beeLogger.Warn(formatLog(f, v...))
+}
+
+// Notice logs a message at notice level.
+func Notice(f interface{}, v ...interface{}) {
+ beeLogger.Notice(formatLog(f, v...))
+}
+
+// Informational logs a message at info level.
+func Informational(f interface{}, v ...interface{}) {
+ beeLogger.Info(formatLog(f, v...))
+}
+
+// Info compatibility alias for Warning()
+func Info(f interface{}, v ...interface{}) {
+ beeLogger.Info(formatLog(f, v...))
+}
+
+// Debug logs a message at debug level.
+func Debug(f interface{}, v ...interface{}) {
+ beeLogger.Debug(formatLog(f, v...))
+}
+
+// Trace logs a message at trace level.
+// compatibility alias for Warning()
+func Trace(f interface{}, v ...interface{}) {
+ beeLogger.Trace(formatLog(f, v...))
+}
+
+func formatLog(f interface{}, v ...interface{}) string {
+ var msg string
+ switch f.(type) {
+ case string:
+ msg = f.(string)
+ if len(v) == 0 {
+ return msg
+ }
+ if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") {
+ //format string
+ } else {
+ //do not contain format char
+ msg += strings.Repeat(" %v", len(v))
+ }
+ default:
+ msg = fmt.Sprint(f)
+ if len(v) == 0 {
+ return msg
+ }
+ msg += strings.Repeat(" %v", len(v))
+ }
+ return fmt.Sprintf(msg, v...)
+}
diff --git a/logs/logger.go b/logs/logger.go
index b25bfaef..2f47e569 100644
--- a/logs/logger.go
+++ b/logs/logger.go
@@ -36,43 +36,46 @@ func (lg *logWriter) println(when time.Time, msg string) {
lg.Unlock()
}
+const y1 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
+const y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
+const mo1 = `000000000111`
+const mo2 = `123456789012`
+const d1 = `0000000001111111111222222222233`
+const d2 = `1234567890123456789012345678901`
+const h1 = `000000000011111111112222`
+const h2 = `012345678901234567890123`
+const mi1 = `000000000011111111112222222222333333333344444444445555555555`
+const mi2 = `012345678901234567890123456789012345678901234567890123456789`
+const s1 = `000000000011111111112222222222333333333344444444445555555555`
+const s2 = `012345678901234567890123456789012345678901234567890123456789`
+
func formatTimeHeader(when time.Time) ([]byte, int) {
y, mo, d := when.Date()
h, mi, s := when.Clock()
- //len(2006/01/02 15:03:04)==19
+ //len("2006/01/02 15:04:05 ")==20
var buf [20]byte
- t := 3
- for y >= 10 {
- p := y / 10
- buf[t] = byte('0' + y - p*10)
- y = p
- t--
- }
- buf[0] = byte('0' + y)
+
+ //change to '3' after 984 years, LOL
+ buf[0] = '2'
+ //change to '1' after 84 years, LOL
+ buf[1] = '0'
+ buf[2] = y1[y-2000]
+ buf[3] = y2[y-2000]
buf[4] = '/'
- if mo > 9 {
- buf[5] = '1'
- buf[6] = byte('0' + mo - 9)
- } else {
- buf[5] = '0'
- buf[6] = byte('0' + mo)
- }
+ buf[5] = mo1[mo-1]
+ buf[6] = mo2[mo-1]
buf[7] = '/'
- t = d / 10
- buf[8] = byte('0' + t)
- buf[9] = byte('0' + d - t*10)
+ buf[8] = d1[d-1]
+ buf[9] = d2[d-1]
buf[10] = ' '
- t = h / 10
- buf[11] = byte('0' + t)
- buf[12] = byte('0' + h - t*10)
+ buf[11] = h1[h]
+ buf[12] = h2[h]
buf[13] = ':'
- t = mi / 10
- buf[14] = byte('0' + t)
- buf[15] = byte('0' + mi - t*10)
+ buf[14] = mi1[mi]
+ buf[15] = mi2[mi]
buf[16] = ':'
- t = s / 10
- buf[17] = byte('0' + t)
- buf[18] = byte('0' + s - t*10)
+ buf[17] = s1[s]
+ buf[18] = s2[s]
buf[19] = ' '
return buf[0:], d
diff --git a/logs/logger_test.go b/logs/logger_test.go
new file mode 100644
index 00000000..4627853a
--- /dev/null
+++ b/logs/logger_test.go
@@ -0,0 +1,57 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "testing"
+ "time"
+)
+
+func TestFormatHeader_0(t *testing.T) {
+ tm := time.Now()
+ if tm.Year() >= 2100 {
+ t.FailNow()
+ }
+ dur := time.Second
+ for {
+ if tm.Year() >= 2100 {
+ break
+ }
+ h, _ := formatTimeHeader(tm)
+ if tm.Format("2006/01/02 15:04:05 ") != string(h) {
+ t.Log(tm)
+ t.FailNow()
+ }
+ tm = tm.Add(dur)
+ dur *= 2
+ }
+}
+
+func TestFormatHeader_1(t *testing.T) {
+ tm := time.Now()
+ year := tm.Year()
+ dur := time.Second
+ for {
+ if tm.Year() >= year+1 {
+ break
+ }
+ h, _ := formatTimeHeader(tm)
+ if tm.Format("2006/01/02 15:04:05 ") != string(h) {
+ t.Log(tm)
+ t.FailNow()
+ }
+ tm = tm.Add(dur)
+ }
+}
diff --git a/logs/multifile.go b/logs/multifile.go
index b82ba274..63204e17 100644
--- a/logs/multifile.go
+++ b/logs/multifile.go
@@ -112,5 +112,5 @@ func newFilesWriter() Logger {
}
func init() {
- Register("multifile", newFilesWriter)
+ Register(AdapterMultiFile, newFilesWriter)
}
diff --git a/logs/smtp.go b/logs/smtp.go
index 47f5a0c6..834130ef 100644
--- a/logs/smtp.go
+++ b/logs/smtp.go
@@ -156,5 +156,5 @@ func (s *SMTPWriter) Destroy() {
}
func init() {
- Register("smtp", newSMTPWriter)
+ Register(AdapterMail, newSMTPWriter)
}
diff --git a/migration/migration.go b/migration/migration.go
index 1591bc50..c9ca1bc6 100644
--- a/migration/migration.go
+++ b/migration/migration.go
@@ -33,7 +33,7 @@ import (
"strings"
"time"
- "github.com/astaxie/beego"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/orm"
)
@@ -90,7 +90,7 @@ func (m *Migration) Reset() {
func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm()
for _, s := range m.sqls {
- beego.Info("exec sql:", s)
+ logs.Info("exec sql:", s)
r := o.Raw(s)
_, err := r.Exec()
if err != nil {
@@ -144,20 +144,20 @@ func Upgrade(lasttime int64) error {
i := 0
for _, v := range sm {
if v.created > lasttime {
- beego.Info("start upgrade", v.name)
+ logs.Info("start upgrade", v.name)
v.m.Reset()
v.m.Up()
err := v.m.Exec(v.name, "up")
if err != nil {
- beego.Error("execute error:", err)
+ logs.Error("execute error:", err)
time.Sleep(2 * time.Second)
return err
}
- beego.Info("end upgrade:", v.name)
+ logs.Info("end upgrade:", v.name)
i++
}
}
- beego.Info("total success upgrade:", i, " migration")
+ logs.Info("total success upgrade:", i, " migration")
time.Sleep(2 * time.Second)
return nil
}
@@ -165,20 +165,20 @@ func Upgrade(lasttime int64) error {
// Rollback rollback the migration by the name
func Rollback(name string) error {
if v, ok := migrationMap[name]; ok {
- beego.Info("start rollback")
+ logs.Info("start rollback")
v.Reset()
v.Down()
err := v.Exec(name, "down")
if err != nil {
- beego.Error("execute error:", err)
+ logs.Error("execute error:", err)
time.Sleep(2 * time.Second)
return err
}
- beego.Info("end rollback")
+ logs.Info("end rollback")
time.Sleep(2 * time.Second)
return nil
}
- beego.Error("not exist the migrationMap name:" + name)
+ logs.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name)
}
@@ -191,23 +191,23 @@ func Reset() error {
for j := len(sm) - 1; j >= 0; j-- {
v := sm[j]
if isRollBack(v.name) {
- beego.Info("skip the", v.name)
+ logs.Info("skip the", v.name)
time.Sleep(1 * time.Second)
continue
}
- beego.Info("start reset:", v.name)
+ logs.Info("start reset:", v.name)
v.m.Reset()
v.m.Down()
err := v.m.Exec(v.name, "down")
if err != nil {
- beego.Error("execute error:", err)
+ logs.Error("execute error:", err)
time.Sleep(2 * time.Second)
return err
}
i++
- beego.Info("end reset:", v.name)
+ logs.Info("end reset:", v.name)
}
- beego.Info("total success reset:", i, " migration")
+ logs.Info("total success reset:", i, " migration")
time.Sleep(2 * time.Second)
return nil
}
@@ -216,7 +216,7 @@ func Reset() error {
func Refresh() error {
err := Reset()
if err != nil {
- beego.Error("execute error:", err)
+ logs.Error("execute error:", err)
time.Sleep(2 * time.Second)
return err
}
@@ -265,7 +265,7 @@ func isRollBack(name string) bool {
var maps []orm.Params
num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps)
if err != nil {
- beego.Info("get name has error", err)
+ logs.Info("get name has error", err)
return false
}
if num <= 0 {
diff --git a/namespace.go b/namespace.go
index 4007d44c..cfde0111 100644
--- a/namespace.go
+++ b/namespace.go
@@ -44,7 +44,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
return ns
}
-// Cond set condtion function
+// Cond set condition function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
@@ -60,7 +60,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
exception("405", ctx)
}
}
- if v, ok := n.handlers.filters[BeforeRouter]; ok {
+ if v := n.handlers.filters[BeforeRouter]; len(v) > 0 {
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = "*"
diff --git a/namespace_test.go b/namespace_test.go
index a92ae3ef..fc02b5fb 100644
--- a/namespace_test.go
+++ b/namespace_test.go
@@ -61,8 +61,8 @@ func TestNamespaceNest(t *testing.T) {
ns.Namespace(
NewNamespace("/admin").
Get("/order", func(ctx *context.Context) {
- ctx.Output.Body([]byte("order"))
- }),
+ ctx.Output.Body([]byte("order"))
+ }),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
@@ -79,8 +79,8 @@ func TestNamespaceNestParam(t *testing.T) {
ns.Namespace(
NewNamespace("/admin").
Get("/order/:id", func(ctx *context.Context) {
- ctx.Output.Body([]byte(ctx.Input.Param(":id")))
- }),
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ }),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
@@ -124,8 +124,8 @@ func TestNamespaceFilter(t *testing.T) {
ctx.Output.Body([]byte("this is Filter"))
}).
Get("/user/:id", func(ctx *context.Context) {
- ctx.Output.Body([]byte(ctx.Input.Param(":id")))
- })
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ })
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "this is Filter" {
diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go
index da0ee8ab..8119b70b 100644
--- a/orm/cmd_utils.go
+++ b/orm/cmd_utils.go
@@ -52,9 +52,15 @@ checkColumn:
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
- col = fmt.Sprintf(T["string"], fieldSize)
+ if al.Driver == DRPostgres && fi.toText {
+ col = T["string-text"]
+ } else {
+ col = fmt.Sprintf(T["string"], fieldSize)
+ }
case TypeTextField:
col = T["string-text"]
+ case TypeTimeField:
+ col = T["time.Time-clock"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
@@ -88,6 +94,18 @@ checkColumn:
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
+ case TypeJSONField:
+ if al.Driver != DRPostgres {
+ fieldType = TypeCharField
+ goto checkColumn
+ }
+ col = T["json"]
+ case TypeJsonbField:
+ if al.Driver != DRPostgres {
+ fieldType = TypeCharField
+ goto checkColumn
+ }
+ col = T["jsonb"]
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
@@ -264,7 +282,7 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
- case TypeDateField, TypeDateTimeField, TypeTextField:
+ case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v
case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
@@ -276,6 +294,8 @@ func getColumnDefault(fi *fieldInfo) string {
case TypeBooleanField:
t = " DEFAULT %s "
d = "FALSE"
+ case TypeJSONField, TypeJsonbField:
+ d = "{}"
}
if fi.colDefault {
diff --git a/orm/db.go b/orm/db.go
index 314c3535..9964e263 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -24,6 +24,7 @@ import (
)
const (
+ formatTime = "15:04:05"
formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05"
)
@@ -71,12 +72,12 @@ type dbBase struct {
var _ dbBaser = new(dbBase)
// get struct columns values as interface slice.
-func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
- var columns []string
-
- if names != nil {
- columns = *names
+func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
+ if names == nil {
+ ns := make([]string, 0, len(cols))
+ names = &ns
}
+ values = make([]interface{}, 0, len(cols))
for _, column := range cols {
var fi *fieldInfo
@@ -90,18 +91,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
}
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- if names != nil {
- columns = append(columns, column)
+ // ignore empty value auto field
+ if insert && fi.auto {
+ if fi.fieldType&IsPositiveIntegerField > 0 {
+ if vu, ok := value.(uint64); !ok || vu == 0 {
+ continue
+ }
+ } else {
+ if vu, ok := value.(int64); !ok || vu == 0 {
+ continue
+ }
+ }
+ autoFields = append(autoFields, fi.column)
}
- values = append(values, value)
- }
-
- if names != nil {
- *names = columns
+ *names, values = append(*names, column), append(values, value)
}
return
@@ -134,7 +141,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else {
value = field.Bool()
}
- case TypeCharField, TypeTextField:
+ case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok {
value = nil
if ns.Valid {
@@ -169,7 +176,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
value = field.Float()
}
}
- case TypeDateField, TypeDateTimeField:
+ case TypeTimeField, TypeDateField, TypeDateTimeField:
value = field.Interface()
if t, ok := value.(time.Time); ok {
d.ins.TimeToDB(&t, tz)
@@ -181,7 +188,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
default:
switch {
- case fi.fieldType&IsPostiveIntegerField > 0:
+ case fi.fieldType&IsPositiveIntegerField > 0:
if field.Kind() == reflect.Ptr {
if field.IsNil() {
value = nil
@@ -223,7 +230,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
}
switch fi.fieldType {
- case TypeDateField, TypeDateTimeField:
+ case TypeTimeField, TypeDateField, TypeDateTimeField:
if fi.autoNow || fi.autoNowAdd && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
@@ -236,10 +243,21 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc))
+ } else if field.Kind() == reflect.Ptr {
+ v := tnow.In(DefaultTimeLoc)
+ field.Set(reflect.ValueOf(&v))
} else {
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
}
}
+ case TypeJSONField, TypeJsonbField:
+ if s, ok := value.(string); (ok && len(s) == 0) || value == nil {
+ if fi.colDefault && fi.initial.Exist() {
+ value = fi.initial.String()
+ } else {
+ value = nil
+ }
+ }
}
}
return value, nil
@@ -273,7 +291,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
- values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
+ values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return 0, err
}
@@ -300,7 +318,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if len(cols) > 0 {
var err error
whereCols = make([]string, 0, len(cols))
- args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
+ args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil {
return err
}
@@ -349,13 +367,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
- names := make([]string, 0, len(mi.fields.dbcols)-1)
- values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
+ names := make([]string, 0, len(mi.fields.dbcols))
+ values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil {
return 0, err
}
- return d.InsertValue(q, mi, false, names, values)
+ id, err := d.InsertValue(q, mi, false, names, values)
+ if err != nil {
+ return 0, err
+ }
+
+ if len(autoFields) > 0 {
+ err = d.ins.setval(q, mi, autoFields)
+ }
+ return id, err
}
// multi-insert sql with given slice struct reflect.Value.
@@ -369,7 +395,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// typ := reflect.Indirect(mi.addrField).Type()
- length := sind.Len()
+ length, autoFields := sind.Len(), make([]string, 0, 1)
for i := 1; i <= length; i++ {
@@ -381,16 +407,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// }
if i == 1 {
- vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
+ var (
+ vus []interface{}
+ err error
+ )
+ vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
-
} else {
-
- vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
+ vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
if err != nil {
return cnt, err
}
@@ -412,7 +440,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
}
}
- return cnt, nil
+ var err error
+ if len(autoFields) > 0 {
+ err = d.ins.setval(q, mi, autoFields)
+ }
+
+ return cnt, err
}
// execute insert sql with given struct and given values.
@@ -472,7 +505,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
setNames = make([]string, 0, len(cols))
}
- setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
+ setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil {
return 0, err
}
@@ -516,7 +549,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
}
if num > 0 {
if mi.fields.pk.auto {
- if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
+ if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
@@ -1071,13 +1104,13 @@ setValue:
}
value = b
}
- case fieldType == TypeCharField || fieldType == TypeTextField:
+ case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil {
value = ToStr(val)
} else {
value = str.String()
}
- case fieldType == TypeDateField || fieldType == TypeDateTimeField:
+ case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil {
switch t := val.(type) {
case time.Time:
@@ -1097,15 +1130,20 @@ setValue:
if len(s) >= 19 {
s = s[:19]
t, err = time.ParseInLocation(formatDateTime, s, tz)
- } else {
+ } else if len(s) >= 10 {
if len(s) > 10 {
s = s[:10]
}
t, err = time.ParseInLocation(formatDate, s, tz)
+ } else if len(s) >= 8 {
+ if len(s) > 8 {
+ s = s[:8]
+ }
+ t, err = time.ParseInLocation(formatTime, s, tz)
}
t = t.In(DefaultTimeLoc)
- if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
+ if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
tErr = err
goto end
}
@@ -1140,7 +1178,7 @@ setValue:
tErr = err
goto end
}
- if fieldType&IsPostiveIntegerField > 0 {
+ if fieldType&IsPositiveIntegerField > 0 {
v, _ := str.Uint64()
value = v
} else {
@@ -1212,7 +1250,7 @@ setValue:
field.SetBool(value.(bool))
}
}
- case fieldType == TypeCharField || fieldType == TypeTextField:
+ case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative {
if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil {
@@ -1234,12 +1272,18 @@ setValue:
field.SetString(value.(string))
}
}
- case fieldType == TypeDateField || fieldType == TypeDateTimeField:
+ case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative {
if value == nil {
value = time.Time{}
+ } else if field.Kind() == reflect.Ptr {
+ if value != nil {
+ v := value.(time.Time)
+ field.Set(reflect.ValueOf(&v))
+ }
+ } else {
+ field.Set(reflect.ValueOf(value))
}
- field.Set(reflect.ValueOf(value))
}
case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr:
if value != nil {
@@ -1292,7 +1336,7 @@ setValue:
field.Set(reflect.ValueOf(&v))
}
case fieldType&IsIntegerField > 0:
- if fieldType&IsPostiveIntegerField > 0 {
+ if fieldType&IsPositiveIntegerField > 0 {
if isNative {
if value == nil {
value = uint64(0)
@@ -1440,7 +1484,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
sels := strings.Join(cols, ", ")
- query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
+ sqlSelect := "SELECT"
+ if qs.distinct {
+ sqlSelect += " DISTINCT"
+ }
+ query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
@@ -1562,6 +1610,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false
}
+// sync auto key
+func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
+ return nil
+}
+
// convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
diff --git a/orm/db_postgres.go b/orm/db_postgres.go
index 7dbef95a..e972c4a2 100644
--- a/orm/db_postgres.go
+++ b/orm/db_postgres.go
@@ -56,6 +56,8 @@ var postgresTypes = map[string]string{
"uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
+ "json": "json",
+ "jsonb": "jsonb",
}
// postgresql dbBaser.
@@ -123,14 +125,35 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
}
// make returning sql support for postgresql.
-func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
- if mi.fields.pk.auto {
- if query != nil {
- *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column)
- }
- has = true
+func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
+ fi := mi.fields.pk
+ if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
+ return false
}
- return
+
+ if query != nil {
+ *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
+ }
+ return true
+}
+
+// sync auto key
+func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
+ if len(autoFields) == 0 {
+ return nil
+ }
+
+ Q := d.ins.TableQuote()
+ for _, name := range autoFields {
+ query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
+ mi.table, name,
+ Q, name, Q,
+ Q, mi.table, Q)
+ if _, err := db.Exec(query); err != nil {
+ return err
+ }
+ }
+ return nil
}
// show table sql for postgresql.
diff --git a/orm/db_utils.go b/orm/db_utils.go
index c97caf36..cf465d02 100644
--- a/orm/db_utils.go
+++ b/orm/db_utils.go
@@ -33,13 +33,13 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
fi := mi.fields.pk
v := ind.FieldByIndex(fi.fieldIndex)
- if fi.fieldType&IsPostiveIntegerField > 0 {
+ if fi.fieldType&IsPositiveIntegerField > 0 {
vu := v.Uint()
exist = vu > 0
value = vu
} else if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
- exist = vu > 0
+ exist = true
value = vu
} else {
vu := v.String()
@@ -74,24 +74,32 @@ outFor:
case reflect.String:
v := val.String()
if fi != nil {
- if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
+ if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
- } else {
+ } else if len(v) >= 10 {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(formatDate, s, tz)
+ } else {
+ s := v
+ if len(s) > 8 {
+ s = v[:8]
+ }
+ t, err = time.ParseInLocation(formatTime, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate)
- } else {
+ } else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime)
+ } else {
+ v = t.In(tz).Format(formatTime)
}
}
}
@@ -137,8 +145,10 @@ outFor:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate)
- } else {
+ } else if fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime)
+ } else {
+ arg = v.In(tz).Format(formatTime)
}
} else {
typ := val.Type()
diff --git a/orm/models_boot.go b/orm/models_boot.go
index 3690557b..c9905330 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -66,7 +66,7 @@ func registerModel(prefix string, model interface{}) {
}
if info.fields.pk == nil {
- fmt.Printf(" `%s` need a primary key field\n", name)
+ fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name)
os.Exit(2)
}
diff --git a/orm/models_fields.go b/orm/models_fields.go
index a8cf8e4f..57820600 100644
--- a/orm/models_fields.go
+++ b/orm/models_fields.go
@@ -25,6 +25,7 @@ const (
TypeBooleanField = 1 << iota
TypeCharField
TypeTextField
+ TypeTimeField
TypeDateField
TypeDateTimeField
TypeBitField
@@ -37,6 +38,8 @@ const (
TypePositiveBigIntegerField
TypeFloatField
TypeDecimalField
+ TypeJSONField
+ TypeJsonbField
RelForeignKey
RelOneToOne
RelManyToMany
@@ -46,10 +49,10 @@ const (
// Define some logic enum
const (
- IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
- IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
- IsRelField = ^-RelReverseMany >> 14 << 15
- IsFieldType = ^-RelReverseMany<<1 + 1
+ IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6
+ IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10
+ IsRelField = ^-RelReverseMany >> 17 << 18
+ IsFieldType = ^-RelReverseMany<<1 + 1
)
// BooleanField A true/false field.
@@ -145,6 +148,65 @@ func (e *CharField) RawValue() interface{} {
// verify CharField implement Fielder
var _ Fielder = new(CharField)
+// TimeField A time, represented in go by a time.Time instance.
+// only time values like 10:00:00
+// Has a few extra, optional attr tag:
+//
+// auto_now:
+// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
+// Note that the current date is always used; it’s not just a default value that you can override.
+//
+// auto_now_add:
+// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
+// Note that the current date is always used; it’s not just a default value that you can override.
+//
+// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
+type TimeField time.Time
+
+// Value return the time.Time
+func (e TimeField) Value() time.Time {
+ return time.Time(e)
+}
+
+// Set set the TimeField's value
+func (e *TimeField) Set(d time.Time) {
+ *e = TimeField(d)
+}
+
+// String convert time to string
+func (e *TimeField) String() string {
+ return e.Value().String()
+}
+
+// FieldType return enum type Date
+func (e *TimeField) FieldType() int {
+ return TypeDateField
+}
+
+// SetRaw convert the interface to time.Time. Allow string and time.Time
+func (e *TimeField) SetRaw(value interface{}) error {
+ switch d := value.(type) {
+ case time.Time:
+ e.Set(d)
+ case string:
+ v, err := timeParse(d, formatTime)
+ if err != nil {
+ e.Set(v)
+ }
+ return err
+ default:
+ return fmt.Errorf(" unknown value `%s`", value)
+ }
+ return nil
+}
+
+// RawValue return time value
+func (e *TimeField) RawValue() interface{} {
+ return e.Value()
+}
+
+var _ Fielder = new(TimeField)
+
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
@@ -627,3 +689,87 @@ func (e *TextField) RawValue() interface{} {
// verify TextField implement Fielder
var _ Fielder = new(TextField)
+
+// JSONField postgres json field.
+type JSONField string
+
+// Value return JSONField value
+func (j JSONField) Value() string {
+ return string(j)
+}
+
+// Set the JSONField value
+func (j *JSONField) Set(d string) {
+ *j = JSONField(d)
+}
+
+// String convert JSONField to string
+func (j *JSONField) String() string {
+ return j.Value()
+}
+
+// FieldType return enum type
+func (j *JSONField) FieldType() int {
+ return TypeJSONField
+}
+
+// SetRaw convert interface string to string
+func (j *JSONField) SetRaw(value interface{}) error {
+ switch d := value.(type) {
+ case string:
+ j.Set(d)
+ default:
+ return fmt.Errorf(" unknown value `%s`", value)
+ }
+ return nil
+}
+
+// RawValue return JSONField value
+func (j *JSONField) RawValue() interface{} {
+ return j.Value()
+}
+
+// verify JSONField implement Fielder
+var _ Fielder = new(JSONField)
+
+// JsonbField postgres json field.
+type JsonbField string
+
+// Value return JsonbField value
+func (j JsonbField) Value() string {
+ return string(j)
+}
+
+// Set the JsonbField value
+func (j *JsonbField) Set(d string) {
+ *j = JsonbField(d)
+}
+
+// String convert JsonbField to string
+func (j *JsonbField) String() string {
+ return j.Value()
+}
+
+// FieldType return enum type
+func (j *JsonbField) FieldType() int {
+ return TypeJsonbField
+}
+
+// SetRaw convert interface string to string
+func (j *JsonbField) SetRaw(value interface{}) error {
+ switch d := value.(type) {
+ case string:
+ j.Set(d)
+ default:
+ return fmt.Errorf(" unknown value `%s`", value)
+ }
+ return nil
+}
+
+// RawValue return JsonbField value
+func (j *JsonbField) RawValue() interface{} {
+ return j.Value()
+}
+
+// verify JsonbField implement Fielder
+var _ Fielder = new(JsonbField)
diff --git a/orm/models_info_f.go b/orm/models_info_f.go
index 996a2f40..be6c9aa4 100644
--- a/orm/models_info_f.go
+++ b/orm/models_info_f.go
@@ -119,6 +119,7 @@ type fieldInfo struct {
colDefault bool
initial StrTo
size int
+ toText bool
autoNow bool
autoNowAdd bool
rel bool
@@ -239,8 +240,15 @@ checkType:
if err != nil {
goto end
}
- if fieldType == TypeCharField && tags["type"] == "text" {
- fieldType = TypeTextField
+ if fieldType == TypeCharField {
+ switch tags["type"] {
+ case "text":
+ fieldType = TypeTextField
+ case "json":
+ fieldType = TypeJSONField
+ case "jsonb":
+ fieldType = TypeJsonbField
+ }
}
if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField
@@ -248,6 +256,9 @@ checkType:
if fieldType == TypeDateTimeField && tags["type"] == "date" {
fieldType = TypeDateField
}
+ if fieldType == TypeTimeField && tags["type"] == "time" {
+ fieldType = TypeTimeField
+ }
}
switch fieldType {
@@ -339,7 +350,7 @@ checkType:
switch fieldType {
case TypeBooleanField:
- case TypeCharField:
+ case TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" {
v, e := StrTo(size).Int32()
if e != nil {
@@ -349,11 +360,12 @@ checkType:
}
} else {
fi.size = 255
+ fi.toText = true
}
case TypeTextField:
fi.index = false
fi.unique = false
- case TypeDateField, TypeDateTimeField:
+ case TypeTimeField, TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.autoNow = true
} else if attrs["auto_now_add"] {
@@ -406,7 +418,7 @@ checkType:
fi.index = false
}
- if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField {
+ if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default
initial.Clear()
}
diff --git a/orm/models_test.go b/orm/models_test.go
index ffb16ea0..462370b2 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -78,40 +78,43 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField)
// A json field.
-type JSONField struct {
+type JSONFieldTest struct {
Name string
Data string
}
-func (e *JSONField) String() string {
+func (e *JSONFieldTest) String() string {
data, _ := json.Marshal(e)
return string(data)
}
-func (e *JSONField) FieldType() int {
+func (e *JSONFieldTest) FieldType() int {
return TypeTextField
}
-func (e *JSONField) SetRaw(value interface{}) error {
+func (e *JSONFieldTest) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
default:
- return fmt.Errorf(" unknown value `%v`", value)
+ return fmt.Errorf(" unknown value `%v`", value)
}
}
-func (e *JSONField) RawValue() interface{} {
+func (e *JSONFieldTest) RawValue() interface{} {
return e.String()
}
-var _ Fielder = new(JSONField)
+var _ Fielder = new(JSONFieldTest)
type Data struct {
ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
+ JSON string `orm:"type(json);default({\"name\":\"json\"})"`
+ Jsonb string `orm:"type(jsonb)"`
+ Time time.Time `orm:"type(time)"`
Date time.Time `orm:"type(date)"`
DateTime time.Time `orm:"column(datetime)"`
Byte byte
@@ -136,6 +139,9 @@ type DataNull struct {
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
+ JSON string `orm:"type(json);null"`
+ Jsonb string `orm:"type(jsonb);null"`
+ Time time.Time `orm:"null;type(time)"`
Date time.Time `orm:"null;type(date)"`
DateTime time.Time `orm:"null;column(datetime)"`
Byte byte `orm:"null"`
@@ -175,6 +181,9 @@ type DataNull struct {
Float32Ptr *float32 `orm:"null"`
Float64Ptr *float64 `orm:"null"`
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
+ TimePtr *time.Time `orm:"null;type(time)"`
+ DatePtr *time.Time `orm:"null;type(date)"`
+ DateTimePtr *time.Time `orm:"null"`
}
type String string
@@ -237,7 +246,7 @@ type User struct {
ShouldSkip string `orm:"-"`
Nums int
Langs SliceStringField `orm:"size(100)"`
- Extra JSONField `orm:"type(text)"`
+ Extra JSONFieldTest `orm:"type(text)"`
unexport bool `orm:"-"`
unexportBool bool
}
@@ -375,6 +384,28 @@ func NewInLine() *InLine {
return new(InLine)
}
+type InLineOneToOne struct {
+ // Common Fields
+ ModelBase
+
+ Note string
+ InLine *InLine `orm:"rel(fk);column(inline)"`
+}
+
+func NewInLineOneToOne() *InLineOneToOne {
+ return new(InLineOneToOne)
+}
+
+type IntegerPk struct {
+ ID int64 `orm:"pk"`
+ Value string
+}
+
+type UintPk struct {
+ ID uint32 `orm:"pk"`
+ Name string
+}
+
var DBARGS = struct {
Driver string
Source string
diff --git a/orm/models_utils.go b/orm/models_utils.go
index ec11d516..4c4b0f24 100644
--- a/orm/models_utils.go
+++ b/orm/models_utils.go
@@ -137,6 +137,8 @@ func getFieldType(val reflect.Value) (ft int, err error) {
ft = TypeBooleanField
case reflect.TypeOf(new(string)):
ft = TypeCharField
+ case reflect.TypeOf(new(time.Time)):
+ ft = TypeDateTimeField
default:
elm := reflect.Indirect(val)
switch elm.Kind() {
@@ -192,10 +194,10 @@ func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string
tag := make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) {
v = strings.TrimSpace(v)
- if supportTag[v] == 1 {
- attr[v] = true
+ if t := strings.ToLower(v); supportTag[t] == 1 {
+ attr[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
- name := v[:i]
+ name := t[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
tag[name] = v
diff --git a/orm/orm.go b/orm/orm.go
index 0ffb6b86..5e43ae59 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -140,7 +140,14 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
return (err == nil), id, err
}
- return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err
+ id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
+ if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
+ id = int64(vid.Uint())
+ } else {
+ id = vid.Int()
+ }
+
+ return false, id, err
}
// insert model data to database
@@ -159,7 +166,7 @@ func (o *orm) Insert(md interface{}) (int64, error) {
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
- if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
+ if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
@@ -184,7 +191,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
- ind := sind.Index(i)
+ ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
diff --git a/orm/orm_log.go b/orm/orm_log.go
index 712eb219..54723273 100644
--- a/orm/orm_log.go
+++ b/orm/orm_log.go
@@ -31,7 +31,7 @@ type Log struct {
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
- d.Logger = log.New(out, "[ORM]", 1e9)
+ d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d
}
diff --git a/orm/orm_object.go b/orm/orm_object.go
index 8a5d85e2..de3181ce 100644
--- a/orm/orm_object.go
+++ b/orm/orm_object.go
@@ -50,7 +50,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
}
if id > 0 {
if o.mi.fields.pk.auto {
- if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
+ if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
diff --git a/orm/orm_test.go b/orm/orm_test.go
index d9fd7d51..b5973448 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"fmt"
"io/ioutil"
+ "math"
"os"
"path/filepath"
"reflect"
@@ -33,6 +34,7 @@ var _ = os.PathSeparator
var (
testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700"
+ testTime = formatTime + " -0700"
)
type argAny []interface{}
@@ -188,6 +190,9 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine))
+ RegisterModel(new(InLineOneToOne))
+ RegisterModel(new(IntegerPk))
+ RegisterModel(new(UintPk))
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
@@ -208,6 +213,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine))
+ RegisterModel(new(InLineOneToOne))
+ RegisterModel(new(IntegerPk))
+ RegisterModel(new(UintPk))
BootStrap()
@@ -233,6 +241,9 @@ var DataValues = map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
+ "JSON": `{"name":"json"}`,
+ "Jsonb": `{"name": "jsonb"}`,
+ "Time": time.Now(),
"Date": time.Now(),
"DateTime": time.Now(),
"Byte": byte(1<<8 - 1),
@@ -257,10 +268,12 @@ func TestDataTypes(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range DataValues {
+ if name == "JSON" {
+ continue
+ }
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
-
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
@@ -281,6 +294,9 @@ func TestDataTypes(t *testing.T) {
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ case "Time":
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -299,10 +315,18 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
+ data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}`
+ d = DataNull{ID: 1, JSON: data}
+ num, err := dORM.Update(&d)
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
+
d = DataNull{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
+ throwFail(t, AssertIs(d.JSON, data))
+
throwFail(t, AssertIs(d.NullBool.Valid, false))
throwFail(t, AssertIs(d.NullString.Valid, false))
throwFail(t, AssertIs(d.NullInt64.Valid, false))
@@ -326,6 +350,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.Float32Ptr, nil))
throwFail(t, AssertIs(d.Float64Ptr, nil))
throwFail(t, AssertIs(d.DecimalPtr, nil))
+ throwFail(t, AssertIs(d.TimePtr, nil))
+ throwFail(t, AssertIs(d.DatePtr, nil))
+ throwFail(t, AssertIs(d.DateTimePtr, nil))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
@@ -352,6 +379,9 @@ func TestNullDataTypes(t *testing.T) {
float32Ptr := float32(42.0)
float64Ptr := float64(42.0)
decimalPtr := float64(42.0)
+ timePtr := time.Now()
+ datePtr := time.Now()
+ dateTimePtr := time.Now()
d = DataNull{
DateTime: time.Now(),
@@ -377,6 +407,9 @@ func TestNullDataTypes(t *testing.T) {
Float32Ptr: &float32Ptr,
Float64Ptr: &float64Ptr,
DecimalPtr: &decimalPtr,
+ TimePtr: &timePtr,
+ DatePtr: &datePtr,
+ DateTimePtr: &dateTimePtr,
}
id, err = dORM.Insert(&d)
@@ -417,6 +450,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
+ throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime)))
+ throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate)))
+ throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime)))
}
func TestDataCustomTypes(t *testing.T) {
@@ -1521,6 +1557,7 @@ func TestRawQueryRow(t *testing.T) {
Boolean bool
Char string
Text string
+ Time time.Time
Date time.Time
DateTime time.Time
Byte byte
@@ -1549,14 +1586,14 @@ func TestRawQueryRow(t *testing.T) {
Q := dDbBaser.TableQuote()
cols := []string{
- "id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
+ "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
var id int
values := []interface{}{
- &id, &Boolean, &Char, &Text, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32,
+ &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32,
&Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal,
}
err := dORM.Raw(query, 1).QueryRow(values...)
@@ -1567,6 +1604,10 @@ func TestRawQueryRow(t *testing.T) {
switch col {
case "id":
throwFail(t, AssertIs(id, 1))
+ case "time":
+ v = v.(time.Time).In(DefaultTimeLoc)
+ value := dataValues[col].(time.Time).In(DefaultTimeLoc)
+ throwFail(t, AssertIs(v, value, testTime))
case "date":
v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
@@ -1614,6 +1655,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
+ case "Time":
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@@ -1638,6 +1682,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
+ case "Time":
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@@ -1959,3 +2006,171 @@ func TestInLine(t *testing.T) {
throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime))
}
+
+func TestInLineOneToOne(t *testing.T) {
+ name := "121"
+ email := "121@go.com"
+ inline := NewInLine()
+ inline.Name = name
+ inline.Email = email
+
+ id, err := dORM.Insert(inline)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id, 2))
+
+ note := "one2one"
+ il121 := NewInLineOneToOne()
+ il121.Note = note
+ il121.InLine = inline
+ _, err = dORM.Insert(il121)
+ throwFail(t, err)
+ throwFail(t, AssertIs(il121.ID, 1))
+
+ il := NewInLineOneToOne()
+ err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il)
+
+ throwFail(t, err)
+ throwFail(t, AssertIs(il.Note, note))
+ throwFail(t, AssertIs(il.InLine.ID, id))
+ throwFail(t, AssertIs(il.InLine.Name, name))
+ throwFail(t, AssertIs(il.InLine.Email, email))
+
+ rinline := NewInLine()
+ err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline)
+
+ throwFail(t, err)
+ throwFail(t, AssertIs(rinline.ID, id))
+ throwFail(t, AssertIs(rinline.Name, name))
+ throwFail(t, AssertIs(rinline.Email, email))
+}
+
+func TestIntegerPk(t *testing.T) {
+ its := []IntegerPk{
+ {ID: math.MinInt64, Value: "-"},
+ {ID: 0, Value: "0"},
+ {ID: math.MaxInt64, Value: "+"},
+ }
+
+ num, err := dORM.InsertMulti(len(its), its)
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, len(its)))
+
+ for _, intPk := range its {
+ out := IntegerPk{ID: intPk.ID}
+ err = dORM.Read(&out)
+ throwFail(t, err)
+ throwFail(t, AssertIs(out.Value, intPk.Value))
+ }
+
+ num, err = dORM.InsertMulti(1, []*IntegerPk{&IntegerPk{
+ ID: 1, Value: "ok",
+ }})
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
+}
+
+func TestInsertAuto(t *testing.T) {
+ u := &User{
+ UserName: "autoPre",
+ Email: "autoPre@gmail.com",
+ }
+
+ id, err := dORM.Insert(u)
+ throwFail(t, err)
+
+ id += 100
+ su := &User{
+ ID: int(id),
+ UserName: "auto",
+ Email: "auto@gmail.com",
+ }
+
+ nid, err := dORM.Insert(su)
+ throwFail(t, err)
+ throwFail(t, AssertIs(nid, id))
+
+ users := []User{
+ {ID: int(id + 100), UserName: "auto_100"},
+ {ID: int(id + 110), UserName: "auto_110"},
+ {ID: int(id + 120), UserName: "auto_120"},
+ }
+ num, err := dORM.InsertMulti(100, users)
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 3))
+
+ u = &User{
+ UserName: "auto_121",
+ }
+
+ nid, err = dORM.Insert(u)
+ throwFail(t, err)
+ throwFail(t, AssertIs(nid, id+120+1))
+}
+
+func TestUintPk(t *testing.T) {
+ name := "go"
+ u := &UintPk{
+ ID: 8,
+ Name: name,
+ }
+
+ created, pk, err := dORM.ReadOrCreate(u, "ID")
+ throwFail(t, err)
+ throwFail(t, AssertIs(created, true))
+ throwFail(t, AssertIs(u.Name, name))
+
+ nu := &UintPk{ID: 8}
+ created, pk, err = dORM.ReadOrCreate(nu, "ID")
+ throwFail(t, err)
+ throwFail(t, AssertIs(created, false))
+ throwFail(t, AssertIs(nu.ID, u.ID))
+ throwFail(t, AssertIs(pk, u.ID))
+ throwFail(t, AssertIs(nu.Name, name))
+
+ dORM.Delete(u)
+}
+
+func TestSnake(t *testing.T) {
+ cases := map[string]string{
+ "i": "i",
+ "I": "i",
+ "iD": "i_d",
+ "ID": "id",
+ "NO": "no",
+ "NOO": "noo",
+ "NOOooOOoo": "noo_oo_oo_oo",
+ "OrderNO": "order_no",
+ "tagName": "tag_name",
+ "tag_Name": "tag_name",
+ "tag_name": "tag_name",
+ "_tag_name": "_tag_name",
+ "tag_666name": "tag_666name",
+ "tag_666Name": "tag_666_name",
+ }
+ for name, want := range cases {
+ got := snakeString(name)
+ throwFail(t, AssertIs(got, want))
+ }
+}
+
+func TestIgnoreCaseTag(t *testing.T) {
+ type testTagModel struct {
+ ID int `orm:"pk"`
+ NOO string `orm:"column(n)"`
+ Name01 string `orm:"NULL"`
+ Name02 string `orm:"COLUMN(Name)"`
+ Name03 string `orm:"Column(name)"`
+ }
+ modelCache.clean()
+ RegisterModel(&testTagModel{})
+ info, ok := modelCache.get("test_tag_model")
+ throwFail(t, AssertIs(ok, true))
+ throwFail(t, AssertNot(info, nil))
+ if t == nil {
+ return
+ }
+ throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n"))
+ throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true))
+ throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
+ throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
+}
diff --git a/orm/types.go b/orm/types.go
index 41933dd1..cb55e71a 100644
--- a/orm/types.go
+++ b/orm/types.go
@@ -420,4 +420,5 @@ type dbBaser interface {
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
+ setval(dbQuerier, *modelInfo, []string) error
}
diff --git a/orm/utils.go b/orm/utils.go
index 99437c7b..e3cd8ad6 100644
--- a/orm/utils.go
+++ b/orm/utils.go
@@ -181,18 +181,36 @@ func ToInt64(value interface{}) (d int64) {
return
}
-// snake string, XxYy to xx_yy
+// snake string, XxYy to xx_yy , XxYY to xx_yy
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
- j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
- if i > 0 && d >= 'A' && d <= 'Z' && j {
- data = append(data, '_')
- }
- if d != '_' {
- j = true
+ if i > 0 && d != '_' && s[i-1] != '_' {
+ need := false
+ // upper as 1, lower as 0
+ // XX -> 11 -> 11
+ // Xx -> 10 -> 10
+ // XxYyZZ -> 101011 -> 10_10_11
+ isUpper := d >= 'A' && d <= 'Z'
+ preIsUpper := s[i-1] >= 'A' && s[i-1] <= 'Z'
+ if isUpper {
+ // like : xxYy
+ if !preIsUpper {
+ need = true
+ }
+ } else {
+ if preIsUpper {
+ // ignore "Xy" in "xxXyy"
+ if i-2 >= 0 && s[i-2] >= 'A' && s[i-2] <= 'Z' {
+ need = true
+ }
+ }
+ }
+ if need {
+ data = append(data, '_')
+ }
}
data = append(data, d)
}
diff --git a/parser.go b/parser.go
index 46d02320..caa2b38b 100644
--- a/parser.go
+++ b/parser.go
@@ -23,10 +23,11 @@ import (
"go/token"
"io/ioutil"
"os"
- "path"
+ "path/filepath"
"sort"
"strings"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -55,10 +56,11 @@ func init() {
}
func parserPkg(pkgRealpath, pkgpath string) error {
- rep := strings.NewReplacer("/", "_", ".", "_")
- commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go"
+ rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
+ commentFilename, _ = filepath.Rel(AppPath, pkgRealpath)
+ commentFilename = coomentPrefix + rep.Replace(commentFilename) + ".go"
if !compareFile(pkgRealpath) {
- Info(pkgRealpath + " no changed")
+ logs.Info(pkgRealpath + " no changed")
return nil
}
genInfoList = make(map[string][]ControllerComments)
@@ -86,7 +88,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
}
}
}
- genRouterCode()
+ genRouterCode(pkgRealpath)
savetoFile(pkgRealpath)
return nil
}
@@ -129,9 +131,9 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
return nil
}
-func genRouterCode() {
- os.Mkdir(path.Join(AppPath, "routers"), 0755)
- Info("generate router from comments")
+func genRouterCode(pkgRealpath string) {
+ os.Mkdir(getRouterDir(pkgRealpath), 0755)
+ logs.Info("generate router from comments")
var (
globalinfo string
sortKey []string
@@ -164,15 +166,15 @@ func genRouterCode() {
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
- "` + strings.TrimSpace(c.Method) + `",
- ` + "`" + c.Router + "`" + `,
- ` + allmethod + `,
- ` + params + `})
+ Method: "` + strings.TrimSpace(c.Method) + `",
+ ` + "Router: `" + c.Router + "`" + `,
+ AllowHTTPMethods: ` + allmethod + `,
+ Params: ` + params + `})
`
}
}
if globalinfo != "" {
- f, err := os.Create(path.Join(AppPath, "routers", commentFilename))
+ f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil {
panic(err)
}
@@ -182,7 +184,7 @@ func genRouterCode() {
}
func compareFile(pkgRealpath string) bool {
- if !utils.FileExists(path.Join(AppPath, "routers", commentFilename)) {
+ if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
return true
}
if utils.FileExists(lastupdateFilename) {
@@ -229,3 +231,19 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
}
return lastupdate, nil
}
+
+func getRouterDir(pkgRealpath string) string {
+ dir := filepath.Dir(pkgRealpath)
+ for {
+ d := filepath.Join(dir, "routers")
+ if utils.FileExists(d) {
+ return d
+ }
+
+ if r, _ := filepath.Rel(dir, AppPath); r == "." {
+ return d
+ }
+ // Parent dir.
+ dir = filepath.Dir(dir)
+ }
+}
diff --git a/router.go b/router.go
index d0bf534f..960cd104 100644
--- a/router.go
+++ b/router.go
@@ -28,6 +28,7 @@ import (
"time"
beecontext "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
)
@@ -114,7 +115,7 @@ type controllerInfo struct {
type ControllerRegister struct {
routers map[string]*Tree
enableFilter bool
- filters map[int][]*FilterRouter
+ filters [FinishRouter + 1][]*FilterRouter
pool sync.Pool
}
@@ -122,7 +123,6 @@ type ControllerRegister struct {
func NewControllerRegister() *ControllerRegister {
cr := &ControllerRegister{
routers: make(map[string]*Tree),
- filters: make(map[int][]*FilterRouter),
}
cr.pool.New = func() interface{} {
return beecontext.NewContext()
@@ -408,7 +408,6 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
// InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
-
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = pattern
@@ -426,9 +425,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
}
// add Filter into
-func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error {
- p.filters[pos] = append(p.filters[pos], mr)
+func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
+ if pos < BeforeStatic || pos > FinishRouter {
+ err = fmt.Errorf("can not find your filter postion")
+ return
+ }
p.enableFilter = true
+ p.filters[pos] = append(p.filters[pos], mr)
return nil
}
@@ -437,11 +440,11 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error
func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".")
if len(paths) <= 1 {
- Warn("urlfor endpoint must like path.controller.method")
+ logs.Warn("urlfor endpoint must like path.controller.method")
return ""
}
if len(values)%2 != 0 {
- Warn("urlfor params must key-value pair")
+ logs.Warn("urlfor params must key-value pair")
return ""
}
params := make(map[string]string)
@@ -577,20 +580,16 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return false, ""
}
-func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) {
- if p.enableFilter {
- if l, ok := p.filters[pos]; ok {
- for _, filterR := range l {
- if filterR.returnOnOutput && context.ResponseWriter.Started {
- return true
- }
- if ok := filterR.ValidRouter(urlPath, context); ok {
- filterR.filterFunc(context)
- }
- if filterR.returnOnOutput && context.ResponseWriter.Started {
- return true
- }
- }
+func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
+ for _, filterR := range p.filters[pos] {
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ if ok := filterR.ValidRouter(urlPath, context); ok {
+ filterR.filterFunc(context)
+ }
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
}
}
return false
@@ -617,11 +616,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Output.Header("Server", BConfig.ServerName)
}
- var urlPath string
+ var urlPath = r.URL.Path
+
if !BConfig.RouterCaseSensitive {
- urlPath = strings.ToLower(r.URL.Path)
- } else {
- urlPath = r.URL.Path
+ urlPath = strings.ToLower(urlPath)
}
// filter wrong http method
@@ -631,11 +629,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
// filter for static file
- if p.execFilter(context, BeforeStatic, urlPath) {
+ if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) {
goto Admin
}
serverStaticRouter(context)
+
if context.ResponseWriter.Started {
findRouter = true
goto Admin
@@ -653,9 +652,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil {
- Error(err)
+ logs.Error(err)
exception("503", context)
- return
+ goto Admin
}
defer func() {
if context.Input.CruSession != nil {
@@ -663,8 +662,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}()
}
-
- if p.execFilter(context, BeforeRouter, urlPath) {
+ if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
goto Admin
}
@@ -693,7 +691,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if findRouter {
//execute middleware filters
- if p.execFilter(context, BeforeExec, urlPath) {
+ if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
goto Admin
}
isRunnable := false
@@ -783,7 +781,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
- panic(err)
+ logs.Error(err)
}
}
}
@@ -794,17 +792,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
//execute middleware filters
- if p.execFilter(context, AfterExec, urlPath) {
+ if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
goto Admin
}
}
-
- p.execFilter(context, FinishRouter, urlPath)
+ if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) {
+ goto Admin
+ }
Admin:
- timeDur := time.Since(startTime)
//admin module record QPS
if BConfig.Listen.EnableAdmin {
+ timeDur := time.Since(startTime)
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
@@ -815,6 +814,7 @@ Admin:
}
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
+ timeDur := time.Since(startTime)
var devInfo string
if findRouter {
if routerInfo != nil {
@@ -826,7 +826,7 @@ Admin:
devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
}
if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
- Debug(devInfo)
+ logs.Debug(devInfo)
}
}
@@ -843,27 +843,26 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
}
if !BConfig.RecoverPanic {
panic(err)
- } else {
- if BConfig.EnableErrorsShow {
- if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
- exception(fmt.Sprint(err), context)
- return
- }
+ }
+ if BConfig.EnableErrorsShow {
+ if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
+ exception(fmt.Sprint(err), context)
+ return
}
- var stack string
- Critical("the request url is ", context.Input.URL())
- Critical("Handler crashed with error", err)
- for i := 1; ; i++ {
- _, file, line, ok := runtime.Caller(i)
- if !ok {
- break
- }
- Critical(fmt.Sprintf("%s:%d", file, line))
- stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
- }
- if BConfig.RunMode == DEV {
- showErr(err, context, stack)
+ }
+ var stack string
+ logs.Critical("the request url is ", context.Input.URL())
+ logs.Critical("Handler crashed with error", err)
+ for i := 1; ; i++ {
+ _, file, line, ok := runtime.Caller(i)
+ if !ok {
+ break
}
+ logs.Critical(fmt.Sprintf("%s:%d", file, line))
+ stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
+ }
+ if BConfig.RunMode == DEV {
+ showErr(err, context, stack)
}
}
}
diff --git a/router_test.go b/router_test.go
index f26f0c86..9f11286c 100644
--- a/router_test.go
+++ b/router_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
)
type TestController struct {
@@ -94,7 +95,7 @@ func TestUrlFor(t *testing.T) {
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
if a := handler.URLFor("TestController.List"); a != "/api/list" {
- Info(a)
+ logs.Info(a)
t.Errorf("TestController.List must equal to /api/list")
}
if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
@@ -120,24 +121,24 @@ func TestUrlFor2(t *testing.T) {
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
- Info(handler.URLFor("TestController.GetURL"))
+ logs.Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit")
}
if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" {
- Info(handler.URLFor("TestController.List"))
+ logs.Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
}
if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" {
- Info(handler.URLFor("TestController.Param"))
+ logs.Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
}
if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" {
- Info(handler.URLFor("TestController.Get"))
+ logs.Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
}
}
diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go
index 969d26c9..838ec669 100644
--- a/session/mysql/sess_mysql.go
+++ b/session/mysql/sess_mysql.go
@@ -115,7 +115,6 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
}
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
-
}
// Provider mysql session provider
diff --git a/session/sess_file.go b/session/sess_file.go
index 9265b030..91acfcd4 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -16,7 +16,6 @@ package session
import (
"errors"
- "fmt"
"io"
"io/ioutil"
"net/http"
@@ -82,14 +81,17 @@ func (fs *FileSessionStore) SessionID() string {
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
b, err := EncodeGob(fs.values)
if err != nil {
+ SLogger.Println(err)
return
}
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
+ SLogger.Println(err)
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
+ SLogger.Println(err)
} else {
return
}
@@ -123,7 +125,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil {
- println(err.Error())
+ SLogger.Println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File
@@ -191,7 +193,7 @@ func (fp *FileProvider) SessionAll() int {
return a.visit(path, f, err)
})
if err != nil {
- fmt.Printf("filepath.Walk() returned %v\n", err)
+ SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0
}
return a.total
@@ -205,11 +207,11 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil {
- println(err.Error())
+ SLogger.Println(err.Error())
}
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil {
- println(err.Error())
+ SLogger.Println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File
diff --git a/session/session.go b/session/session.go
index 9fe99a17..73f0d677 100644
--- a/session/session.go
+++ b/session/session.go
@@ -31,9 +31,14 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
+ "io"
+ "log"
"net/http"
+ "net/textproto"
"net/url"
+ "os"
"time"
)
@@ -61,6 +66,9 @@ type Provider interface {
var provides = make(map[string]Provider)
+// SLogger a helpful variable to log information about session
+var SLogger = NewSessionLog(os.Stderr)
+
// Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
@@ -75,15 +83,18 @@ func Register(name string, provide Provider) {
}
type managerConfig struct {
- CookieName string `json:"cookieName"`
- EnableSetCookie bool `json:"enableSetCookie,omitempty"`
- Gclifetime int64 `json:"gclifetime"`
- Maxlifetime int64 `json:"maxLifetime"`
- Secure bool `json:"secure"`
- CookieLifeTime int `json:"cookieLifeTime"`
- ProviderConfig string `json:"providerConfig"`
- Domain string `json:"domain"`
- SessionIDLength int64 `json:"sessionIDLength"`
+ CookieName string `json:"cookieName"`
+ EnableSetCookie bool `json:"enableSetCookie,omitempty"`
+ Gclifetime int64 `json:"gclifetime"`
+ Maxlifetime int64 `json:"maxLifetime"`
+ Secure bool `json:"secure"`
+ CookieLifeTime int `json:"cookieLifeTime"`
+ ProviderConfig string `json:"providerConfig"`
+ Domain string `json:"domain"`
+ SessionIDLength int64 `json:"sessionIDLength"`
+ EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"`
+ SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"`
+ EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"`
}
// Manager contains Provider and its configuration.
@@ -118,6 +129,19 @@ func NewManager(provideName, config string) (*Manager, error) {
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
+
+ if cf.EnableSidInHttpHeader {
+ if cf.SessionNameInHttpHeader == "" {
+ panic(errors.New("SessionNameInHttpHeader is empty"))
+ }
+
+ strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader)
+ if cf.SessionNameInHttpHeader != strMimeHeader {
+ strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader
+ panic(errors.New(strErrMsg))
+ }
+ }
+
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
@@ -143,12 +167,24 @@ func NewManager(provideName, config string) (*Manager, error) {
func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
- errs := r.ParseForm()
- if errs != nil {
- return "", errs
+ var sid string
+ if manager.config.EnableSidInUrlQuery {
+ errs := r.ParseForm()
+ if errs != nil {
+ return "", errs
+ }
+
+ sid = r.FormValue(manager.config.CookieName)
+ }
+
+ // if not found in Cookie / param, then read it from request headers
+ if manager.config.EnableSidInHttpHeader && sid == "" {
+ sids, isFound := r.Header[manager.config.SessionNameInHttpHeader]
+ if isFound && len(sids) != 0 {
+ return sids[0], nil
+ }
}
- sid := r.FormValue(manager.config.CookieName)
return sid, nil
}
@@ -192,11 +228,21 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
}
r.AddCookie(cookie)
+ if manager.config.EnableSidInHttpHeader {
+ r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
+ w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
+ }
+
return
}
// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
+ if manager.config.EnableSidInHttpHeader {
+ r.Header.Del(manager.config.SessionNameInHttpHeader)
+ w.Header().Del(manager.config.SessionNameInHttpHeader)
+ }
+
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
@@ -261,6 +307,12 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
+
+ if manager.config.EnableSidInHttpHeader {
+ r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
+ w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
+ }
+
return
}
@@ -296,3 +348,15 @@ func (manager *Manager) isSecure(req *http.Request) bool {
}
return true
}
+
+// Log implement the log.Logger
+type Log struct {
+ *log.Logger
+}
+
+// NewSessionLog set io.Writer to create a Logger for session.
+func NewSessionLog(out io.Writer) *Log {
+ sl := new(Log)
+ sl.Logger = log.New(out, "[SESSION]", 1e9)
+ return sl
+}
diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go
new file mode 100644
index 00000000..4dcf160a
--- /dev/null
+++ b/session/ssdb/sess_ssdb.go
@@ -0,0 +1,192 @@
+package ssdb
+
+import (
+ "errors"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/astaxie/beego/session"
+ "github.com/ssdb/gossdb/ssdb"
+)
+
+var ssdbProvider = &SsdbProvider{}
+
+type SsdbProvider struct {
+ client *ssdb.Client
+ host string
+ port int
+ maxLifetime int64
+}
+
+func (p *SsdbProvider) connectInit() error {
+ var err error
+ if p.host == "" || p.port == 0 {
+ return errors.New("SessionInit First")
+ }
+ p.client, err = ssdb.Connect(p.host, p.port)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error {
+ var e error = nil
+ p.maxLifetime = maxLifetime
+ address := strings.Split(savePath, ":")
+ p.host = address[0]
+ p.port, e = strconv.Atoi(address[1])
+ if e != nil {
+ return e
+ }
+ err := p.connectInit()
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) {
+ if p.client == nil {
+ if err := p.connectInit(); err != nil {
+ return nil, err
+ }
+ }
+ var kv map[interface{}]interface{}
+ value, err := p.client.Get(sid)
+ if err != nil {
+ return nil, err
+ }
+ if value == nil || len(value.(string)) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob([]byte(value.(string)))
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
+ return rs, nil
+}
+
+func (p *SsdbProvider) SessionExist(sid string) bool {
+ if p.client == nil {
+ if err := p.connectInit(); err != nil {
+ panic(err)
+ }
+ }
+ value, err := p.client.Get(sid)
+ if err != nil {
+ panic(err)
+ }
+ if value == nil || len(value.(string)) == 0 {
+ return false
+ }
+ return true
+
+}
+func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+ //conn.Do("setx", key, v, ttl)
+ if p.client == nil {
+ if err := p.connectInit(); err != nil {
+ return nil, err
+ }
+ }
+ value, err := p.client.Get(oldsid)
+ if err != nil {
+ return nil, err
+ }
+ var kv map[interface{}]interface{}
+ if value == nil || len(value.(string)) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob([]byte(value.(string)))
+ if err != nil {
+ return nil, err
+ }
+ _, err = p.client.Del(oldsid)
+ if err != nil {
+ return nil, err
+ }
+ }
+ _, e := p.client.Do("setx", sid, value, p.maxLifetime)
+ if e != nil {
+ return nil, e
+ }
+ rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
+ return rs, nil
+}
+
+func (p *SsdbProvider) SessionDestroy(sid string) error {
+ if p.client == nil {
+ if err := p.connectInit(); err != nil {
+ return err
+ }
+ }
+ _, err := p.client.Del(sid)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *SsdbProvider) SessionGC() {
+ return
+}
+
+func (p *SsdbProvider) SessionAll() int {
+ return 0
+}
+
+type SessionStore struct {
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxLifetime int64
+ client *ssdb.Client
+}
+
+func (s *SessionStore) Set(key, value interface{}) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ s.values[key] = value
+ return nil
+}
+func (s *SessionStore) Get(key interface{}) interface{} {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ if value, ok := s.values[key]; ok {
+ return value
+ }
+ return nil
+}
+
+func (s *SessionStore) Delete(key interface{}) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ delete(s.values, key)
+ return nil
+}
+func (s *SessionStore) Flush() error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ s.values = make(map[interface{}]interface{})
+ return nil
+}
+func (s *SessionStore) SessionID() string {
+ return s.sid
+}
+
+func (s *SessionStore) SessionRelease(w http.ResponseWriter) {
+ b, err := session.EncodeGob(s.values)
+ if err != nil {
+ return
+ }
+ s.client.Do("setx", s.sid, string(b), s.maxLifetime)
+
+}
+func init() {
+ session.Register("ssdb", ssdbProvider)
+}
diff --git a/staticfile.go b/staticfile.go
index 4b19f949..11a2cdcc 100644
--- a/staticfile.go
+++ b/staticfile.go
@@ -27,6 +27,7 @@ import (
"time"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
)
var errNotStaticRequest = errors.New("request not a static file request")
@@ -48,14 +49,19 @@ func serverStaticRouter(ctx *context.Context) {
if filePath == "" || fileInfo == nil {
if BConfig.RunMode == DEV {
- Warn("Can't find/open the file:", filePath, err)
+ logs.Warn("Can't find/open the file:", filePath, err)
}
http.NotFound(ctx.ResponseWriter, ctx.Request)
return
}
if fileInfo.IsDir() {
- //serveFile will list dir
- http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+ requestURL := ctx.Input.URL()
+ if requestURL[len(requestURL)-1] != '/' {
+ ctx.Redirect(302, requestURL+"/")
+ } else {
+ //serveFile will list dir
+ http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+ }
return
}
@@ -67,7 +73,7 @@ func serverStaticRouter(ctx *context.Context) {
b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil {
if BConfig.RunMode == DEV {
- Warn("Can't compress the file:", filePath, err)
+ logs.Warn("Can't compress the file:", filePath, err)
}
http.NotFound(ctx.ResponseWriter, ctx.Request)
return
diff --git a/swagger/docs_spec.go b/swagger/docs_spec.go
deleted file mode 100644
index 680324dc..00000000
--- a/swagger/docs_spec.go
+++ /dev/null
@@ -1,160 +0,0 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package swagger struct definition
-package swagger
-
-// SwaggerVersion show the current swagger version
-const SwaggerVersion = "1.2"
-
-// ResourceListing list the resource
-type ResourceListing struct {
- APIVersion string `json:"apiVersion"`
- SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2
- // BasePath string `json:"basePath"` obsolete in 1.1
- APIs []APIRef `json:"apis"`
- Info Information `json:"info"`
-}
-
-// APIRef description the api path and description
-type APIRef struct {
- Path string `json:"path"` // relative or absolute, must start with /
- Description string `json:"description"`
-}
-
-// Information show the API Information
-type Information struct {
- Title string `json:"title,omitempty"`
- Description string `json:"description,omitempty"`
- Contact string `json:"contact,omitempty"`
- TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
- License string `json:"license,omitempty"`
- LicenseURL string `json:"licenseUrl,omitempty"`
-}
-
-// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
-type APIDeclaration struct {
- APIVersion string `json:"apiVersion"`
- SwaggerVersion string `json:"swaggerVersion"`
- BasePath string `json:"basePath"`
- ResourcePath string `json:"resourcePath"` // must start with /
- Consumes []string `json:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty"`
- APIs []API `json:"apis,omitempty"`
- Models map[string]Model `json:"models,omitempty"`
-}
-
-// API show tha API struct
-type API struct {
- Path string `json:"path"` // relative or absolute, must start with /
- Description string `json:"description"`
- Operations []Operation `json:"operations,omitempty"`
-}
-
-// Operation desc the Operation
-type Operation struct {
- HTTPMethod string `json:"httpMethod"`
- Nickname string `json:"nickname"`
- Type string `json:"type"` // in 1.1 = DataType
- // ResponseClass string `json:"responseClass"` obsolete in 1.2
- Summary string `json:"summary,omitempty"`
- Notes string `json:"notes,omitempty"`
- Parameters []Parameter `json:"parameters,omitempty"`
- ResponseMessages []ResponseMessage `json:"responseMessages,omitempty"` // optional
- Consumes []string `json:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty"`
- Authorizations []Authorization `json:"authorizations,omitempty"`
- Protocols []Protocol `json:"protocols,omitempty"`
-}
-
-// Protocol support which Protocol
-type Protocol struct {
-}
-
-// ResponseMessage Show the
-type ResponseMessage struct {
- Code int `json:"code"`
- Message string `json:"message"`
- ResponseModel string `json:"responseModel"`
-}
-
-// Parameter desc the request parameters
-type Parameter struct {
- ParamType string `json:"paramType"` // path,query,body,header,form
- Name string `json:"name"`
- Description string `json:"description"`
- DataType string `json:"dataType"` // 1.2 needed?
- Type string `json:"type"` // integer
- Format string `json:"format"` // int64
- AllowMultiple bool `json:"allowMultiple"`
- Required bool `json:"required"`
- Minimum int `json:"minimum"`
- Maximum int `json:"maximum"`
-}
-
-// ErrorResponse desc response
-type ErrorResponse struct {
- Code int `json:"code"`
- Reason string `json:"reason"`
-}
-
-// Model define the data model
-type Model struct {
- ID string `json:"id"`
- Required []string `json:"required,omitempty"`
- Properties map[string]ModelProperty `json:"properties"`
-}
-
-// ModelProperty define the properties
-type ModelProperty struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Items map[string]string `json:"items,omitempty"`
- Format string `json:"format"`
-}
-
-// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations
-type Authorization struct {
- LocalOAuth OAuth `json:"local-oauth"`
- APIKey APIKey `json:"apiKey"`
-}
-
-// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations
-type OAuth struct {
- Type string `json:"type"` // e.g. oauth2
- Scopes []string `json:"scopes"` // e.g. PUBLIC
- GrantTypes map[string]GrantType `json:"grantTypes"`
-}
-
-// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations
-type GrantType struct {
- LoginEndpoint Endpoint `json:"loginEndpoint"`
- TokenName string `json:"tokenName"` // e.g. access_code
- TokenRequestEndpoint Endpoint `json:"tokenRequestEndpoint"`
- TokenEndpoint Endpoint `json:"tokenEndpoint"`
-}
-
-// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations
-type Endpoint struct {
- URL string `json:"url"`
- ClientIDName string `json:"clientIdName"`
- ClientSecretName string `json:"clientSecretName"`
- TokenName string `json:"tokenName"`
-}
-
-// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations
-type APIKey struct {
- Type string `json:"type"` // e.g. apiKey
- PassAs string `json:"passAs"` // e.g. header
-}
diff --git a/swagger/swagger.go b/swagger/swagger.go
new file mode 100644
index 00000000..e48dcf1e
--- /dev/null
+++ b/swagger/swagger.go
@@ -0,0 +1,155 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Swagger™ is a project used to describe and document RESTful APIs.
+//
+// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools.
+// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software.
+
+// Package swagger struct definition
+package swagger
+
+// Swagger list the resource
+type Swagger struct {
+ SwaggerVersion string `json:"swagger,omitempty"`
+ Infos Information `json:"info"`
+ Host string `json:"host,omitempty"`
+ BasePath string `json:"basePath,omitempty"`
+ Schemes []string `json:"schemes,omitempty"`
+ Consumes []string `json:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty"`
+ Paths map[string]Item `json:"paths"`
+ Definitions map[string]Schema `json:"definitions,omitempty"`
+ SecurityDefinitions map[string]Scurity `json:"securityDefinitions,omitempty"`
+ Security map[string][]string `json:"security,omitempty"`
+ Tags []Tag `json:"tags,omitempty"`
+ ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
+}
+
+// Information Provides metadata about the API. The metadata can be used by the clients if needed.
+type Information struct {
+ Title string `json:"title,omitempty"`
+ Description string `json:"description,omitempty"`
+ Version string `json:"version,omitempty"`
+ TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
+
+ Contact Contact `json:"contact,omitempty"`
+ License License `json:"license,omitempty"`
+}
+
+// Contact information for the exposed API.
+type Contact struct {
+ Name string `json:"name,omitempty"`
+ URL string `json:"url,omitempty"`
+ EMail string `json:"email,omitempty"`
+}
+
+// License information for the exposed API.
+type License struct {
+ Name string `json:"name,omitempty"`
+ URL string `json:"url,omitempty"`
+}
+
+// Item Describes the operations available on a single path.
+type Item struct {
+ Ref string `json:"$ref,omitempty"`
+ Get *Operation `json:"get,omitempty"`
+ Put *Operation `json:"put,omitempty"`
+ Post *Operation `json:"post,omitempty"`
+ Delete *Operation `json:"delete,omitempty"`
+ Options *Operation `json:"options,omitempty"`
+ Head *Operation `json:"head,omitempty"`
+ Patch *Operation `json:"patch,omitempty"`
+}
+
+// Operation Describes a single API operation on a path.
+type Operation struct {
+ Tags []string `json:"tags,omitempty"`
+ Summary string `json:"summary,omitempty"`
+ Description string `json:"description,omitempty"`
+ OperationID string `json:"operationId,omitempty"`
+ Consumes []string `json:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty"`
+ Schemes []string `json:"schemes,omitempty"`
+ Parameters []Parameter `json:"parameters,omitempty"`
+ Responses map[string]Response `json:"responses,omitempty"`
+ Deprecated bool `json:"deprecated,omitempty"`
+}
+
+// Parameter Describes a single operation parameter.
+type Parameter struct {
+ In string `json:"in,omitempty"`
+ Name string `json:"name,omitempty"`
+ Description string `json:"description,omitempty"`
+ Required bool `json:"required,omitempty"`
+ Schema Schema `json:"schema,omitempty"`
+ Type string `json:"type,omitempty"`
+ Format string `json:"format,omitempty"`
+}
+
+// Schema Object allows the definition of input and output data types.
+type Schema struct {
+ Ref string `json:"$ref,omitempty"`
+ Title string `json:"title,omitempty"`
+ Format string `json:"format,omitempty"`
+ Description string `json:"description,omitempty"`
+ Required []string `json:"required,omitempty"`
+ Type string `json:"type,omitempty"`
+ Properties map[string]Propertie `json:"properties,omitempty"`
+}
+
+// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
+type Propertie struct {
+ Title string `json:"title,omitempty"`
+ Description string `json:"description,omitempty"`
+ Default string `json:"default,omitempty"`
+ Type string `json:"type,omitempty"`
+ Example string `json:"example,omitempty"`
+ Required []string `json:"required,omitempty"`
+ Format string `json:"format,omitempty"`
+ ReadOnly bool `json:"readOnly,omitempty"`
+ Properties map[string]Propertie `json:"properties,omitempty"`
+}
+
+// Response as they are returned from executing this operation.
+type Response struct {
+ Description string `json:"description,omitempty"`
+ Schema Schema `json:"schema,omitempty"`
+ Ref string `json:"$ref,omitempty"`
+}
+
+// Scurity Allows the definition of a security scheme that can be used by the operations
+type Scurity struct {
+ Type string `json:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
+ Description string `json:"description,omitempty"`
+ Name string `json:"name,omitempty"`
+ In string `json:"in,omitempty"` // Valid values are "query" or "header".
+ Flow string `json:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
+ AuthorizationURL string `json:"authorizationUrl,omitempty"`
+ TokenURL string `json:"tokenUrl,omitempty"`
+ Scopes map[string]string `json:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
+}
+
+// Tag Allows adding meta data to a single tag that is used by the Operation Object
+type Tag struct {
+ Name string `json:"name,omitempty"`
+ Description string `json:"description,omitempty"`
+ ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
+}
+
+// ExternalDocs include Additional external documentation
+type ExternalDocs struct {
+ Description string `json:"description,omitempty"`
+ URL string `json:"url,omitempty"`
+}
diff --git a/template.go b/template.go
index e6c43f87..494acc4f 100644
--- a/template.go
+++ b/template.go
@@ -26,6 +26,7 @@ import (
"strings"
"sync"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -36,19 +37,35 @@ var (
templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"}
+ // beeTemplatePreprocessors stores associations of extension -> preprocessor handler
+ beeTemplateEngines = map[string]templatePreProcessor{}
)
-func executeTemplate(wr io.Writer, name string, data interface{}) error {
+// ExecuteTemplate applies the template with name to the specified data object,
+// writing the output to wr.
+// A template will be executed safely in parallel.
+func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
if BConfig.RunMode == DEV {
templatesLock.RLock()
defer templatesLock.RUnlock()
}
if t, ok := beeTemplates[name]; ok {
- err := t.ExecuteTemplate(wr, name, data)
- if err != nil {
- Trace("template Execute err:", err)
+ if t.Lookup(name) != nil {
+ err := t.ExecuteTemplate(wr, name, data)
+ if err != nil {
+ logs.Trace("template Execute err:", err)
+ }
+ return err
+ } else {
+ err := t.Execute(wr, data)
+ if err != nil {
+ if err != nil {
+ logs.Trace("template Execute err:", err)
+ }
+ return err
+ }
}
- return err
+ return nil
}
panic("can't find templatefile in the path:" + name)
}
@@ -88,6 +105,8 @@ func AddFuncMap(key string, fn interface{}) error {
return nil
}
+type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error)
+
type templateFile struct {
root string
files map[string][]string
@@ -156,13 +175,22 @@ func BuildTemplate(dir string, files ...string) error {
fmt.Printf("filepath.Walk() returned %v\n", err)
return err
}
+ buildAllFiles := len(files) == 0
for _, v := range self.files {
for _, file := range v {
- if len(files) == 0 || utils.InSlice(file, files) {
+ if buildAllFiles || utils.InSlice(file, files) {
templatesLock.Lock()
- t, err := getTemplate(self.root, file, v...)
+ ext := filepath.Ext(file)
+ var t *template.Template
+ if len(ext) == 0 {
+ t, err = getTemplate(self.root, file, v...)
+ } else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
+ t, err = fn(self.root, file, beegoTplFuncMap)
+ } else {
+ t, err = getTemplate(self.root, file, v...)
+ }
if err != nil {
- Trace("parse template err:", file, err)
+ logs.Trace("parse template err:", file, err)
} else {
beeTemplates[file] = t
}
@@ -240,7 +268,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil {
- Trace("template parse file err:", err)
+ logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
}
@@ -261,7 +289,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil {
- Trace("template parse file err:", err)
+ logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
}
@@ -305,3 +333,9 @@ func DelStaticPath(url string) *App {
delete(BConfig.WebConfig.StaticDir, url)
return BeeApp
}
+
+func AddTemplateEngine(extension string, fn templatePreProcessor) *App {
+ AddTemplateExt(extension)
+ beeTemplateEngines[extension] = fn
+ return BeeApp
+}
diff --git a/templatefunc.go b/templatefunc.go
index 8558733f..36442984 100644
--- a/templatefunc.go
+++ b/templatefunc.go
@@ -421,18 +421,18 @@ func RenderForm(obj interface{}) template.HTML {
fieldT := objT.Field(i)
- label, name, fType, id, class, ignored := parseFormTag(fieldT)
+ label, name, fType, id, class, ignored, required := parseFormTag(fieldT)
if ignored {
continue
}
- raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class))
+ raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required))
}
return template.HTML(strings.Join(raw, ""))
}
// renderFormField returns a string containing HTML of a single form field.
-func renderFormField(label, name, fType string, value interface{}, id string, class string) string {
+func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string {
if id != "" {
id = " id=\"" + id + "\""
}
@@ -441,11 +441,16 @@ func renderFormField(label, name, fType string, value interface{}, id string, cl
class = " class=\"" + class + "\""
}
- if isValidForInput(fType) {
- return fmt.Sprintf(`%v`, label, id, class, name, fType, value)
+ requiredString := ""
+ if required {
+ requiredString = " required"
}
- return fmt.Sprintf(`%v<%v%v%v name="%v">%v%v>`, label, fType, id, class, name, value, fType)
+ if isValidForInput(fType) {
+ return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString)
+ }
+
+ return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v%v>`, label, fType, id, class, name, requiredString, value, fType)
}
// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element.
@@ -461,7 +466,7 @@ func isValidForInput(fType string) bool {
// parseFormTag takes the stuct-tag of a StructField and parses the `form` value.
// returned are the form label, name-property, type and wether the field should be ignored.
-func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool) {
+func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) {
tags := strings.Split(fieldT.Tag.Get("form"), ",")
label = fieldT.Name + ": "
name = fieldT.Name
@@ -470,6 +475,12 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
id = fieldT.Tag.Get("id")
class = fieldT.Tag.Get("class")
+ required = false
+ required_field := fieldT.Tag.Get("required")
+ if required_field != "-" && required_field != "" {
+ required, _ = strconv.ParseBool(required_field)
+ }
+
switch len(tags) {
case 1:
if tags[0] == "-" {
@@ -496,6 +507,7 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
label = tags[2]
}
}
+
return
}
diff --git a/templatefunc_test.go b/templatefunc_test.go
index 98fbf7ab..86de37ae 100644
--- a/templatefunc_test.go
+++ b/templatefunc_test.go
@@ -195,54 +195,78 @@ func TestRenderForm(t *testing.T) {
}
func TestRenderFormField(t *testing.T) {
- html := renderFormField("Label: ", "Name", "text", "Value", "", "")
+ html := renderFormField("Label: ", "Name", "text", "Value", "", "", false)
if html != `Label: ` {
t.Errorf("Wrong html output for input[type=text]: %v ", html)
}
- html = renderFormField("Label: ", "Name", "textarea", "Value", "", "")
+ html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false)
if html != `Label: ` {
t.Errorf("Wrong html output for textarea: %v ", html)
}
+
+ html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true)
+ if html != `Label: ` {
+ t.Errorf("Wrong html output for textarea: %v ", html)
+ }
}
func TestParseFormTag(t *testing.T) {
// create struct to contain field with different types of struct-tag `form`
type user struct {
- All int `form:"name,text,年龄:"`
- NoName int `form:",hidden,年龄:"`
- OnlyLabel int `form:",,年龄:"`
- OnlyName int `form:"name" id:"name" class:"form-name"`
- Ignored int `form:"-"`
+ All int `form:"name,text,年龄:"`
+ NoName int `form:",hidden,年龄:"`
+ OnlyLabel int `form:",,年龄:"`
+ OnlyName int `form:"name" id:"name" class:"form-name"`
+ Ignored int `form:"-"`
+ Required int `form:"name" required:"true"`
+ IgnoreRequired int `form:"name"`
+ NotRequired int `form:"name" required:"false"`
}
objT := reflect.TypeOf(&user{}).Elem()
- label, name, fType, id, class, ignored := parseFormTag(objT.Field(0))
+ label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0))
if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag with name, label and type was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(1))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1))
if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) {
t.Errorf("Form Tag with label and type but without name was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(2))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2))
if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag containing only label was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(3))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3))
if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false &&
id == "name" && class == "form-name") {
t.Errorf("Form Tag containing only name was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(4))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4))
if ignored == false {
t.Errorf("Form Tag that should be ignored was not correctly parsed.")
}
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5))
+ if !(name == "name" && required == true) {
+ t.Errorf("Form Tag containing only name and required was not correctly parsed.")
+ }
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6))
+ if !(name == "name" && required == false) {
+ t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.")
+ }
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7))
+ if !(name == "name" && required == false) {
+ t.Errorf("Form Tag containing only name and not required was not correctly parsed.")
+ }
+
}
func TestMapGet(t *testing.T) {
diff --git a/toolbox/task.go b/toolbox/task.go
index 537de428..abd411c8 100644
--- a/toolbox/task.go
+++ b/toolbox/task.go
@@ -389,6 +389,10 @@ func dayMatches(s *Schedule, t time.Time) bool {
// StartTask start all tasks
func StartTask() {
+ if isstart {
+ //If already started, no need to start another goroutine.
+ return
+ }
isstart = true
go run()
}
@@ -432,8 +436,11 @@ func run() {
// StopTask stop all tasks
func StopTask() {
- isstart = false
- stop <- true
+ if isstart {
+ isstart = false
+ stop <- true
+ }
+
}
// AddTask add task with name
diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go
index 1a4a6edc..42ac70d3 100644
--- a/utils/captcha/captcha.go
+++ b/utils/captcha/captcha.go
@@ -69,6 +69,7 @@ import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -139,7 +140,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error")
- beego.Error("Reload Create Captcha Error:", err)
+ logs.Error("Reload Create Captcha Error:", err)
return
}
} else {
@@ -154,7 +155,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
- beego.Error("Write Captcha Image Error:", err)
+ logs.Error("Write Captcha Image Error:", err)
}
}
@@ -162,7 +163,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
func (c *Captcha) CreateCaptchaHTML() template.HTML {
value, err := c.CreateCaptcha()
if err != nil {
- beego.Error("Create Captcha Error:", err)
+ logs.Error("Create Captcha Error:", err)
return ""
}
diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go
index 1d99cac5..2f022d0c 100644
--- a/utils/pagination/controller.go
+++ b/utils/pagination/controller.go
@@ -18,7 +18,7 @@ import (
"github.com/astaxie/beego/context"
)
-// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data["paginator"].
+// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator").
func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) {
paginator = NewPaginator(context.Request, per, nums)
context.Input.SetData("paginator", &paginator)
diff --git a/utils/rand.go b/utils/rand.go
index 74bb4121..344d1cd5 100644
--- a/utils/rand.go
+++ b/utils/rand.go
@@ -20,28 +20,24 @@ import (
"time"
)
+var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`)
+
// RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte {
- const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ if len(alphabets) == 0 {
+ alphabets = alphaNum
+ }
var bytes = make([]byte, n)
- var randby bool
+ var randBy bool
if num, err := rand.Read(bytes); num != n || err != nil {
r.Seed(time.Now().UnixNano())
- randby = true
+ randBy = true
}
for i, b := range bytes {
- if len(alphabets) == 0 {
- if randby {
- bytes[i] = alphanum[r.Intn(len(alphanum))]
- } else {
- bytes[i] = alphanum[b%byte(len(alphanum))]
- }
+ if randBy {
+ bytes[i] = alphabets[r.Intn(len(alphabets))]
} else {
- if randby {
- bytes[i] = alphabets[r.Intn(len(alphabets))]
- } else {
- bytes[i] = alphabets[b%byte(len(alphabets))]
- }
+ bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return bytes
diff --git a/docs.go b/utils/rand_test.go
similarity index 50%
rename from docs.go
rename to utils/rand_test.go
index 72532876..6c238b5e 100644
--- a/docs.go
+++ b/utils/rand_test.go
@@ -1,4 +1,4 @@
-// Copyright 2014 beego Author. All Rights Reserved.
+// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,28 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package beego
+package utils
-import (
- "github.com/astaxie/beego/context"
-)
+import "testing"
-// GlobalDocAPI store the swagger api documents
-var GlobalDocAPI = make(map[string]interface{})
+func TestRand_01(t *testing.T) {
+ bs0 := RandomCreateBytes(16)
+ bs1 := RandomCreateBytes(16)
-func serverDocs(ctx *context.Context) {
- var obj interface{}
- if splat := ctx.Input.Param(":splat"); splat == "" {
- obj = GlobalDocAPI["Root"]
- } else {
- if v, ok := GlobalDocAPI[splat]; ok {
- obj = v
- }
+ t.Log(string(bs0), string(bs1))
+ if string(bs0) == string(bs1) {
+ t.FailNow()
}
- if obj != nil {
- ctx.Output.Header("Access-Control-Allow-Origin", "*")
- ctx.Output.JSON(obj, false, false)
- return
+
+ bs0 = RandomCreateBytes(4, []byte(`a`)...)
+
+ if string(bs0) != "aaaa" {
+ t.FailNow()
}
- ctx.Output.SetStatus(404)
}
diff --git a/validation/validation.go b/validation/validation.go
index 2b020aa8..489dfa5e 100644
--- a/validation/validation.go
+++ b/validation/validation.go
@@ -73,6 +73,10 @@ func (e *Error) String() string {
return e.Message
}
+// Implement Error interface.
+// Return e.String()
+func (e *Error) Error() string { return e.String() }
+
// Result is returned from every validation method.
// It provides an indication of success, and a pointer to the Error (if any).
type Result struct {