1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 07:20:55 +00:00
This commit is contained in:
Maxgis 2016-07-24 18:01:19 +08:00
commit 6a4ebc67ac
80 changed files with 2889 additions and 1047 deletions

View File

@ -1,8 +1,7 @@
language: go language: go
go: go:
- tip - 1.6
- 1.6.0
- 1.5.3 - 1.5.3
- 1.4.3 - 1.4.3
services: services:
@ -31,21 +30,20 @@ install:
- go get github.com/belogik/goes - go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis - go get github.com/siddontang/ledisdb/ledis
- go get golang.org/x/tools/cmd/vet
- go get github.com/golang/lint/golint
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
before_script: before_script:
- psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi"
- mkdir -p res/var - mkdir -p res/var
- ./ssdb/ssdb-server ./ssdb/ssdb.conf -d - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d
after_script: after_script:
-killall -w ssdb-server -killall -w ssdb-server
- rm -rf ./res/var/* - rm -rf ./res/var/*
script: script:
- go vet -x ./...
- $HOME/gopath/bin/golint ./...
- go test -v ./... - go test -v ./...
notifications: addons:
webhooks: https://hooks.pubu.im/services/z7m9bvybl3rgtg9 postgresql: "9.4"

View File

@ -30,7 +30,7 @@ func main(){
``` ```
######Congratulations! ######Congratulations!
You just built your first beego app. You just built your first beego app.
Open your browser and visit `http://localhost:8000`. Open your browser and visit `http://localhost:8080`.
Please see [Documentation](http://beego.me/docs) for more. Please see [Documentation](http://beego.me/docs) for more.
## Features ## Features

View File

@ -23,7 +23,10 @@ import (
"text/template" "text/template"
"time" "time"
"reflect"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -90,57 +93,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
switch command { switch command {
case "conf": case "conf":
m := make(map[string]interface{}) m := make(map[string]interface{})
list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath m["AppConfigPath"] = appConfigPath
m["AppConfigProvider"] = appConfigProvider m["AppConfigProvider"] = appConfigProvider
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl)) tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
@ -196,7 +151,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
BeforeExec: "Before Exec", BeforeExec: "Before Exec",
AfterExec: "After Exec", AfterExec: "After Exec",
FinishRouter: "Finish Router"} { FinishRouter: "Finish Router"} {
if bf, ok := BeeApp.Handlers.filters[k]; ok { if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 {
filterType = fr filterType = fr
filterTypes = append(filterTypes, filterType) filterTypes = append(filterTypes, filterType)
resultList := new([][]string) resultList := new([][]string)
@ -223,6 +178,28 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
} }
} }
func list(root string, p interface{}, m map[string]interface{}) {
pt := reflect.TypeOf(p)
pv := reflect.ValueOf(p)
if pt.Kind() == reflect.Ptr {
pt = pt.Elem()
pv = pv.Elem()
}
for i := 0; i < pv.NumField(); i++ {
var key string
if root == "" {
key = pt.Field(i).Name
} else {
key = root + "." + pt.Field(i).Name
}
if pv.Field(i).Kind() == reflect.Struct {
list(key, pv.Field(i).Interface(), m)
} else {
m[key] = pv.Field(i).Interface()
}
}
}
func printTree(resultList *[][]string, t *Tree) { func printTree(resultList *[][]string, t *Tree) {
for _, tr := range t.fixrouters { for _, tr := range t.fixrouters {
printTree(resultList, tr) printTree(resultList, tr)
@ -410,7 +387,7 @@ func (admin *adminApp) Run() {
for p, f := range admin.routers { for p, f := range admin.routers {
http.Handle(p, f) http.Handle(p, f)
} }
BeeLogger.Info("Admin server Running on %s", addr) logs.Info("Admin server Running on %s", addr)
var err error var err error
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
@ -419,6 +396,6 @@ func (admin *adminApp) Run() {
err = http.ListenAndServe(addr, nil) err = http.ListenAndServe(addr, nil)
} }
if err != nil { if err != nil {
BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
} }
} }

72
admin_test.go Normal file
View File

@ -0,0 +1,72 @@
package beego
import (
"testing"
"fmt"
)
func TestList_01(t *testing.T) {
m := make(map[string]interface{})
list("BConfig", BConfig, m)
t.Log(m)
om := oldMap()
for k, v := range om {
if fmt.Sprint(m[k])!= fmt.Sprint(v) {
t.Log(k, "old-key",v,"new-key", m[k])
t.FailNow()
}
}
}
func oldMap() map[string]interface{} {
m := make(map[string]interface{})
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
return m
}

30
app.go
View File

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -68,9 +69,9 @@ func (app *App) Run() {
if BConfig.Listen.EnableFcgi { if BConfig.Listen.EnableFcgi {
if BConfig.Listen.EnableStdIo { if BConfig.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O") logs.Info("Use FCGI via standard I/O")
} else { } else {
BeeLogger.Critical("Cannot use FCGI via standard I/O", err) logs.Critical("Cannot use FCGI via standard I/O", err)
} }
return return
} }
@ -84,10 +85,10 @@ func (app *App) Run() {
l, err = net.Listen("tcp", addr) l, err = net.Listen("tcp", addr)
} }
if err != nil { if err != nil {
BeeLogger.Critical("Listen: ", err) logs.Critical("Listen: ", err)
} }
if err = fcgi.Serve(l, app.Handlers); err != nil { if err = fcgi.Serve(l, app.Handlers); err != nil {
BeeLogger.Critical("fcgi.Serve: ", err) logs.Critical("fcgi.Serve: ", err)
} }
return return
} }
@ -95,6 +96,7 @@ func (app *App) Run() {
app.Server.Handler = app.Handlers app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP")
// run graceful mode // run graceful mode
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
@ -111,7 +113,7 @@ func (app *App) Run() {
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -126,7 +128,7 @@ func (app *App) Run() {
server.Network = "tcp4" server.Network = "tcp4"
} }
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -137,16 +139,18 @@ func (app *App) Run() {
} }
// run normal mode // run normal mode
app.Server.Addr = addr
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
} else if BConfig.Listen.EnableHTTP {
BeeLogger.Info("Start https server error, confict with http.Please reset https port")
return
} }
BeeLogger.Info("https server Running on %s", app.Server.Addr) logs.Info("https server Running on %s", app.Server.Addr)
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err) logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -155,24 +159,24 @@ func (app *App) Run() {
if BConfig.Listen.EnableHTTP { if BConfig.Listen.EnableHTTP {
go func() { go func() {
app.Server.Addr = addr app.Server.Addr = addr
BeeLogger.Info("http server Running on %s", app.Server.Addr) logs.Info("http server Running on %s", app.Server.Addr)
if BConfig.Listen.ListenTCP4 { if BConfig.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr) ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil { if err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return
} }
if err = app.Server.Serve(ln); err != nil { if err = app.Server.Serve(ln); err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return
} }
} else { } else {
if err := app.Server.ListenAndServe(); err != nil { if err := app.Server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }

View File

@ -51,6 +51,7 @@ func AddAPPStartHook(hf hookfunc) {
// beego.Run(":8089") // beego.Run(":8089")
// beego.Run("127.0.0.1:8089") // beego.Run("127.0.0.1:8089")
func Run(params ...string) { func Run(params ...string) {
initBeforeHTTPRun() initBeforeHTTPRun()
if len(params) > 0 && params[0] != "" { if len(params) > 0 && params[0] != "" {
@ -71,9 +72,9 @@ func initBeforeHTTPRun() {
AddAPPStartHook(registerMime) AddAPPStartHook(registerMime)
AddAPPStartHook(registerDefaultErrorHandler) AddAPPStartHook(registerDefaultErrorHandler)
AddAPPStartHook(registerSession) AddAPPStartHook(registerSession)
AddAPPStartHook(registerDocs)
AddAPPStartHook(registerTemplate) AddAPPStartHook(registerTemplate)
AddAPPStartHook(registerAdmin) AddAPPStartHook(registerAdmin)
AddAPPStartHook(registerGzip)
for _, hk := range hooks { for _, hk := range hooks {
if err := hk(); err != nil { if err := hk(); err != nil {
@ -84,8 +85,11 @@ func initBeforeHTTPRun() {
// TestBeegoInit is for test package init // TestBeegoInit is for test package init
func TestBeegoInit(ap string) { func TestBeegoInit(ap string) {
os.Setenv("BEEGO_RUNMODE", "test")
appConfigPath = filepath.Join(ap, "conf", "app.conf") appConfigPath = filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap) os.Chdir(ap)
if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil {
panic(err)
}
BConfig.RunMode = "test"
initBeforeHTTPRun() initBeforeHTTPRun()
} }

View File

@ -18,9 +18,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/garyburd/redigo/redis"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"github.com/garyburd/redigo/redis"
) )
func TestRedisCache(t *testing.T) { func TestRedisCache(t *testing.T) {

View File

@ -1,10 +1,11 @@
package ssdb package ssdb
import ( import (
"github.com/astaxie/beego/cache"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/astaxie/beego/cache"
) )
func TestSsdbcacheCache(t *testing.T) { func TestSsdbcacheCache(t *testing.T) {

188
config.go
View File

@ -18,9 +18,11 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -81,14 +83,17 @@ type WebConfig struct {
// SessionConfig holds session related config // SessionConfig holds session related config
type SessionConfig struct { type SessionConfig struct {
SessionOn bool SessionOn bool
SessionProvider string SessionProvider string
SessionName string SessionName string
SessionGCMaxLifetime int64 SessionGCMaxLifetime int64
SessionProviderConfig string SessionProviderConfig string
SessionCookieLifeTime int SessionCookieLifeTime int
SessionAutoSetCookie bool SessionAutoSetCookie bool
SessionDomain string SessionDomain string
EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers
SessionNameInHttpHeader string
EnableSidInUrlQuery bool // enable get the sessionId from Url Query params
} }
// LogConfig holds Log related config // LogConfig holds Log related config
@ -115,11 +120,30 @@ var (
) )
func init() { func init() {
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) BConfig = newBConfig()
var err error
if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil {
panic(err)
}
workPath, err := os.Getwd()
if err != nil {
panic(err)
}
appConfigPath = filepath.Join(workPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
}
}
if err = parseConfig(appConfigPath); err != nil {
panic(err)
}
}
os.Chdir(AppPath) func newBConfig() *Config {
return &Config{
BConfig = &Config{
AppName: "beego", AppName: "beego",
RunMode: DEV, RunMode: DEV,
RouterCaseSensitive: true, RouterCaseSensitive: true,
@ -162,14 +186,17 @@ func init() {
XSRFKey: "beegoxsrf", XSRFKey: "beegoxsrf",
XSRFExpire: 0, XSRFExpire: 0,
Session: SessionConfig{ Session: SessionConfig{
SessionOn: false, SessionOn: false,
SessionProvider: "memory", SessionProvider: "memory",
SessionName: "beegosessionID", SessionName: "beegosessionID",
SessionGCMaxLifetime: 3600, SessionGCMaxLifetime: 3600,
SessionProviderConfig: "", SessionProviderConfig: "",
SessionCookieLifeTime: 0, //set cookie default is the browser life SessionCookieLifeTime: 0, //set cookie default is the browser life
SessionAutoSetCookie: true, SessionAutoSetCookie: true,
SessionDomain: "", SessionDomain: "",
EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers
SessionNameInHttpHeader: "Beegosessionid",
EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params
}, },
}, },
Log: LogConfig{ Log: LogConfig{
@ -178,16 +205,6 @@ func init() {
Outputs: map[string]string{"console": ""}, Outputs: map[string]string{"console": ""},
}, },
} }
appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
}
if err := parseConfig(appConfigPath); err != nil {
panic(err)
}
} }
// now only support ini, next will support json. // now only support ini, next will support json.
@ -196,63 +213,23 @@ func parseConfig(appConfigPath string) (err error) {
if err != nil { if err != nil {
return err return err
} }
return assignConfig(AppConfig)
}
func assignConfig(ac config.Configer) error {
// set the run mode first // set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode BConfig.RunMode = envRunMode
} else if runMode := AppConfig.String("RunMode"); runMode != "" { } else if runMode := ac.String("RunMode"); runMode != "" {
BConfig.RunMode = runMode BConfig.RunMode = runMode
} }
BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName) for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic) assignSingleConfig(i, ac)
BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive) }
BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
BConfig.Log.AccessLogs = AppConfig.DefaultBool("LogAccessLogs", BConfig.Log.AccessLogs)
BConfig.Log.FileLineNum = AppConfig.DefaultBool("LogFileLineNum", BConfig.Log.FileLineNum)
if sd := AppConfig.String("StaticDir"); sd != "" { if sd := ac.String("StaticDir"); sd != "" {
for k := range BConfig.WebConfig.StaticDir { BConfig.WebConfig.StaticDir = map[string]string{}
delete(BConfig.WebConfig.StaticDir, k)
}
sds := strings.Fields(sd) sds := strings.Fields(sd)
for _, v := range sds { for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
@ -262,7 +239,8 @@ func parseConfig(appConfigPath string) (err error) {
} }
} }
} }
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",") extensions := strings.Split(sgz, ",")
fileExts := []string{} fileExts := []string{}
for _, ext := range extensions { for _, ext := range extensions {
@ -280,7 +258,7 @@ func parseConfig(appConfigPath string) (err error) {
} }
} }
if lo := AppConfig.String("LogOutputs"); lo != "" { if lo := ac.String("LogOutputs"); lo != "" {
los := strings.Split(lo, ";") los := strings.Split(lo, ";")
for _, v := range los { for _, v := range los {
if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 {
@ -292,18 +270,50 @@ func parseConfig(appConfigPath string) (err error) {
} }
//init log //init log
BeeLogger.Reset() logs.Reset()
for adaptor, config := range BConfig.Log.Outputs { for adaptor, config := range BConfig.Log.Outputs {
err = BeeLogger.SetLogger(adaptor, config) err := logs.SetLogger(adaptor, config)
if err != nil { if err != nil {
fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err) fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error()))
} }
} }
SetLogFuncCall(BConfig.Log.FileLineNum) logs.SetLogFuncCall(BConfig.Log.FileLineNum)
return nil return nil
} }
func assignSingleConfig(p interface{}, ac config.Configer) {
pt := reflect.TypeOf(p)
if pt.Kind() != reflect.Ptr {
return
}
pt = pt.Elem()
if pt.Kind() != reflect.Struct {
return
}
pv := reflect.ValueOf(p).Elem()
for i := 0; i < pt.NumField(); i++ {
pf := pv.Field(i)
if !pf.CanSet() {
continue
}
name := pt.Field(i).Name
switch pf.Kind() {
case reflect.String:
pf.SetString(ac.DefaultString(name, pf.String()))
case reflect.Int, reflect.Int64:
pf.SetInt(int64(ac.DefaultInt64(name, pf.Int())))
case reflect.Bool:
pf.SetBool(ac.DefaultBool(name, pf.Bool()))
case reflect.Struct:
default:
//do nothing here
}
}
}
// LoadAppConfig allow developer to apply a config file // LoadAppConfig allow developer to apply a config file
func LoadAppConfig(adapterName, configPath string) error { func LoadAppConfig(adapterName, configPath string) error {
absConfigPath, err := filepath.Abs(configPath) absConfigPath, err := filepath.Abs(configPath)
@ -315,10 +325,6 @@ func LoadAppConfig(adapterName, configPath string) error {
return fmt.Errorf("the target config file: %s don't exist", configPath) return fmt.Errorf("the target config file: %s don't exist", configPath)
} }
if absConfigPath == appConfigPath {
return nil
}
appConfigPath = absConfigPath appConfigPath = absConfigPath
appConfigProvider = adapterName appConfigProvider = adapterName
@ -352,7 +358,7 @@ func (b *beegoAppConfig) String(key string) string {
} }
func (b *beegoAppConfig) Strings(key string) []string { func (b *beegoAppConfig) Strings(key string) []string {
if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" { if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 {
return v return v
} }
return b.innerConfig.Strings(key) return b.innerConfig.Strings(key)

View File

@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package config is used to parse config // Package config is used to parse config.
// Usage: // Usage:
// import( // import "github.com/astaxie/beego/config"
// "github.com/astaxie/beego/config" //Examples.
// )
// //
// cnf, err := config.NewConfig("ini", "config.conf") // cnf, err := config.NewConfig("ini", "config.conf")
// //
@ -38,12 +37,12 @@
// cnf.DIY(key string) (interface{}, error) // cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error) // cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error // cnf.SaveConfigFile(filename string) error
// //More docs http://beego.me/docs/module/config.md
// more docs http://beego.me/docs/module/config.md
package config package config
import ( import (
"fmt" "fmt"
"os"
) )
// Configer defines how to get and set value from configuration raw data. // Configer defines how to get and set value from configuration raw data.
@ -107,6 +106,69 @@ func NewConfigData(adapterName string, data []byte) (Configer, error) {
return adapter.ParseData(data) return adapter.ParseData(data)
} }
// ExpandValueEnvForMap convert all string value with environment variable.
func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} {
for k, v := range m {
switch value := v.(type) {
case string:
m[k] = ExpandValueEnv(value)
case map[string]interface{}:
m[k] = ExpandValueEnvForMap(value)
case map[string]string:
for k2, v2 := range value {
value[k2] = ExpandValueEnv(v2)
}
m[k] = value
}
}
return m
}
// ExpandValueEnv returns value of convert with environment variable.
//
// Return environment variable if value start with "${" and end with "}".
// Return default value if environment variable is empty or not exist.
//
// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue".
// Examples:
// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable.
// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/".
// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie".
func ExpandValueEnv(value string) (realValue string) {
realValue = value
vLen := len(value)
// 3 = ${}
if vLen < 3 {
return
}
// Need start with "${" and end with "}", then return.
if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' {
return
}
key := ""
defalutV := ""
// value start with "${"
for i := 2; i < vLen; i++ {
if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
key = value[2:i]
defalutV = value[i+2 : vLen-1] // other string is default value.
break
} else if value[i] == '}' {
key = value[2:i]
break
}
}
realValue = os.Getenv(key)
if realValue == "" {
realValue = defalutV
}
return
}
// ParseBool returns the boolean value represented by the string. // ParseBool returns the boolean value represented by the string.
// //
// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, // It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On,

55
config/config_test.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"os"
"testing"
)
func TestExpandValueEnv(t *testing.T) {
testCases := []struct {
item string
want string
}{
{"", ""},
{"$", "$"},
{"{", "{"},
{"{}", "{}"},
{"${}", ""},
{"${|}", ""},
{"${}", ""},
{"${{}}", ""},
{"${{||}}", "}"},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}}", "}"},
{"${pwd||{{||}}}", "{{||}}"},
{"${GOPATH}", os.Getenv("GOPATH")},
{"${GOPATH||}", os.Getenv("GOPATH")},
{"${GOPATH||root}", os.Getenv("GOPATH")},
{"${GOPATH_NOT||root}", "root"},
{"${GOPATH_NOT||||root}", "||root"},
}
for _, c := range testCases {
if got := ExpandValueEnv(c.item); got != c.want {
t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got)
}
}
}

View File

@ -38,7 +38,7 @@ func (c *fakeConfigContainer) String(key string) string {
} }
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
v := c.getData(key) v := c.String(key)
if v == "" { if v == "" {
return defaultval return defaultval
} }
@ -46,7 +46,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin
} }
func (c *fakeConfigContainer) Strings(key string) []string { func (c *fakeConfigContainer) Strings(key string) []string {
v := c.getData(key) v := c.String(key)
if v == "" { if v == "" {
return nil return nil
} }

View File

@ -82,6 +82,10 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
if err == io.EOF { if err == io.EOF {
break break
} }
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
if bytes.Equal(line, bEmpty) { if bytes.Equal(line, bEmpty) {
continue continue
} }
@ -162,7 +166,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
val = bytes.Trim(val, `"`) val = bytes.Trim(val, `"`)
} }
cfg.data[section][key] = string(val) cfg.data[section][key] = ExpandValueEnv(string(val))
if comment.Len() > 0 { if comment.Len() > 0 {
cfg.keyComment[section+"."+key] = comment.String() cfg.keyComment[section+"."+key] = comment.String()
comment.Reset() comment.Reset()
@ -296,7 +300,9 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
return nil, errors.New("not exist setction") return nil, errors.New("not exist setction")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file.
//
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation.
func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)

View File

@ -42,11 +42,14 @@ needlogin = ON
enableSession = Y enableSession = Y
enableCookie = N enableCookie = N
flag = 1 flag = 1
path1 = ${GOPATH}
path2 = ${GOPATH||/home/go}
[demo] [demo]
key1="asta" key1="asta"
key2 = "xie" key2 = "xie"
CaseInsensitive = true CaseInsensitive = true
peers = one;two;three peers = one;two;three
password = ${GOPATH}
` `
keyValue = map[string]interface{}{ keyValue = map[string]interface{}{
@ -64,10 +67,13 @@ peers = one;two;three
"enableSession": true, "enableSession": true,
"enableCookie": false, "enableCookie": false,
"flag": true, "flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"demo::key1": "asta", "demo::key1": "asta",
"demo::key2": "xie", "demo::key2": "xie",
"demo::CaseInsensitive": true, "demo::CaseInsensitive": true,
"demo::peers": []string{"one", "two", "three"}, "demo::peers": []string{"one", "two", "three"},
"demo::password": os.Getenv("GOPATH"),
"null": "", "null": "",
"demo2::key1": "", "demo2::key1": "",
"error": "", "error": "",

View File

@ -57,6 +57,9 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
} }
x.data["rootArray"] = wrappingArray x.data["rootArray"] = wrappingArray
} }
x.data = ExpandValueEnvForMap(x.data)
return x, nil return x, nil
} }

View File

@ -86,16 +86,19 @@ func TestJson(t *testing.T) {
"enableSession": "Y", "enableSession": "Y",
"enableCookie": "N", "enableCookie": "N",
"flag": 1, "flag": 1,
"path1": "${GOPATH}",
"path2": "${GOPATH||/home/go}",
"database": { "database": {
"host": "host", "host": "host",
"port": "port", "port": "port",
"database": "database", "database": "database",
"username": "username", "username": "username",
"password": "password", "password": "${GOPATH}",
"conns":{ "conns":{
"maxconnection":12, "maxconnection":12,
"autoconnect":true, "autoconnect":true,
"connectioninfo":"info" "connectioninfo":"info",
"root": "${GOPATH}"
} }
} }
}` }`
@ -115,13 +118,16 @@ func TestJson(t *testing.T) {
"enableSession": true, "enableSession": true,
"enableCookie": false, "enableCookie": false,
"flag": true, "flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"database::host": "host", "database::host": "host",
"database::port": "port", "database::port": "port",
"database::database": "database", "database::database": "database",
"database::password": "password", "database::password": os.Getenv("GOPATH"),
"database::conns::maxconnection": 12, "database::conns::maxconnection": 12,
"database::conns::autoconnect": true, "database::conns::autoconnect": true,
"database::conns::connectioninfo": "info", "database::conns::connectioninfo": "info",
"database::conns::root": os.Getenv("GOPATH"),
"unknown": "", "unknown": "",
} }
) )

View File

@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package xml for config provider // Package xml for config provider.
// //
// depend on github.com/beego/x2j // depend on github.com/beego/x2j.
// //
// go install github.com/beego/x2j // go install github.com/beego/x2j.
// //
// Usage: // Usage:
// import( // import(
// _ "github.com/astaxie/beego/config/xml" // _ "github.com/astaxie/beego/config/xml"
// "github.com/astaxie/beego/config" // "github.com/astaxie/beego/config"
// ) // )
// //
// cnf, err := config.NewConfig("xml", "config.xml") // cnf, err := config.NewConfig("xml", "config.xml")
// //
// more docs http://beego.me/docs/module/config.md //More docs http://beego.me/docs/module/config.md
package xml package xml
import ( import (
@ -69,7 +69,7 @@ func (xc *Config) Parse(filename string) (config.Configer, error) {
return nil, err return nil, err
} }
x.data = d["config"].(map[string]interface{}) x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil return x, nil
} }
@ -92,7 +92,7 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key]; ok { if v := c.data[key]; v != nil {
return config.ParseBool(v) return config.ParseBool(v)
} }
return false, fmt.Errorf("not exist key: %q", key) return false, fmt.Errorf("not exist key: %q", key)

View File

@ -15,14 +15,18 @@
package xml package xml
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
) )
//xml parse should incluce in <config></config> tags func TestXML(t *testing.T) {
var xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
var (
//xml parse should incluce in <config></config> tags
xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
<config> <config>
<appname>beeapi</appname> <appname>beeapi</appname>
<httpport>8080</httpport> <httpport>8080</httpport>
@ -31,10 +35,25 @@ var xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
<runmode>dev</runmode> <runmode>dev</runmode>
<autorender>false</autorender> <autorender>false</autorender>
<copyrequestbody>true</copyrequestbody> <copyrequestbody>true</copyrequestbody>
<path1>${GOPATH}</path1>
<path2>${GOPATH||/home/go}</path2>
</config> </config>
` `
keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
func TestXML(t *testing.T) {
f, err := os.Create("testxml.conf") f, err := os.Create("testxml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -50,39 +69,42 @@ func TestXML(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if xmlconf.String("appname") != "beeapi" {
t.Fatal("appname not equal to beeapi") for k, v := range keyValue {
}
if port, err := xmlconf.Int("httpport"); err != nil || port != 8080 { var (
t.Error(port) value interface{}
t.Fatal(err) err error
} )
if port, err := xmlconf.Int64("mysqlport"); err != nil || port != 3600 {
t.Error(port) switch v.(type) {
t.Fatal(err) case int:
} value, err = xmlconf.Int(k)
if pi, err := xmlconf.Float("PI"); err != nil || pi != 3.1415976 { case int64:
t.Error(pi) value, err = xmlconf.Int64(k)
t.Fatal(err) case float64:
} value, err = xmlconf.Float(k)
if xmlconf.String("runmode") != "dev" { case bool:
t.Fatal("runmode not equal to dev") value, err = xmlconf.Bool(k)
} case []string:
if v, err := xmlconf.Bool("autorender"); err != nil || v != false { value = xmlconf.Strings(k)
t.Error(v) case string:
t.Fatal(err) value = xmlconf.String(k)
} default:
if v, err := xmlconf.Bool("copyrequestbody"); err != nil || v != true { value, err = xmlconf.DIY(k)
t.Error(v) }
t.Fatal(err) if err != nil {
t.Errorf("get key %q value fatal,%v err %s", k, v, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
}
} }
if err = xmlconf.Set("name", "astaxie"); err != nil { if err = xmlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if xmlconf.String("name") != "astaxie" { if xmlconf.String("name") != "astaxie" {
t.Fatal("get name error") t.Fatal("get name error")
} }
if xmlconf.Strings("emptystrings") != nil {
t.Fatal("get emtpy strings error")
}
} }

View File

@ -19,14 +19,14 @@
// go install github.com/beego/goyaml2 // go install github.com/beego/goyaml2
// //
// Usage: // Usage:
// import( // import(
// _ "github.com/astaxie/beego/config/yaml" // _ "github.com/astaxie/beego/config/yaml"
// "github.com/astaxie/beego/config" // "github.com/astaxie/beego/config"
// ) // )
// //
// cnf, err := config.NewConfig("yaml", "config.yaml") // cnf, err := config.NewConfig("yaml", "config.yaml")
// //
// more docs http://beego.me/docs/module/config.md //More docs http://beego.me/docs/module/config.md
package yaml package yaml
import ( import (
@ -110,6 +110,7 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
log.Println("Not a Map? >> ", string(buf), data) log.Println("Not a Map? >> ", string(buf), data)
cnf = nil cnf = nil
} }
cnf = config.ExpandValueEnvForMap(cnf)
return return
} }
@ -121,10 +122,11 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key]; ok { v, err := c.getData(key)
return config.ParseBool(v) if err != nil {
return false, err
} }
return false, fmt.Errorf("not exist key: %q", key) return config.ParseBool(v)
} }
// DefaultBool return the bool value if has no error // DefaultBool return the bool value if has no error
@ -139,8 +141,12 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
func (c *ConfigContainer) Int(key string) (int, error) { func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok { if v, err := c.getData(key); err != nil {
return int(v), nil return 0, err
} else if vv, ok := v.(int); ok {
return vv, nil
} else if vv, ok := v.(int64); ok {
return int(vv), nil
} }
return 0, errors.New("not int value") return 0, errors.New("not int value")
} }
@ -157,8 +163,10 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
func (c *ConfigContainer) Int64(key string) (int64, error) { func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok { if v, err := c.getData(key); err != nil {
return v, nil return 0, err
} else if vv, ok := v.(int64); ok {
return vv, nil
} }
return 0, errors.New("not bool value") return 0, errors.New("not bool value")
} }
@ -175,8 +183,14 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
// Float returns the float value for a given key. // Float returns the float value for a given key.
func (c *ConfigContainer) Float(key string) (float64, error) { func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok { if v, err := c.getData(key); err != nil {
return v, nil return 0.0, err
} else if vv, ok := v.(float64); ok {
return vv, nil
} else if vv, ok := v.(int); ok {
return float64(vv), nil
} else if vv, ok := v.(int64); ok {
return float64(vv), nil
} }
return 0.0, errors.New("not float64 value") return 0.0, errors.New("not float64 value")
} }
@ -193,8 +207,10 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
// String returns the string value for a given key. // String returns the string value for a given key.
func (c *ConfigContainer) String(key string) string { func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, err := c.getData(key); err == nil {
return v if vv, ok := v.(string); ok {
return vv
}
} }
return "" return ""
} }
@ -230,8 +246,8 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
v, ok := c.data[section]
if ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist setction")
@ -259,10 +275,19 @@ func (c *ConfigContainer) Set(key, val string) error {
// DIY returns the raw value by a given key. // DIY returns the raw value by a given key.
func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
return c.getData(key)
}
func (c *ConfigContainer) getData(key string) (interface{}, error) {
if len(key) == 0 {
return nil, errors.New("key is emtpy")
}
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
} }
return nil, errors.New("not exist key") return nil, fmt.Errorf("not exist key %q", key)
} }
func init() { func init() {

View File

@ -15,13 +15,17 @@
package yaml package yaml
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
) )
var yamlcontext = ` func TestYaml(t *testing.T) {
var (
yamlcontext = `
"appname": beeapi "appname": beeapi
"httpport": 8080 "httpport": 8080
"mysqlport": 3600 "mysqlport": 3600
@ -29,9 +33,27 @@ var yamlcontext = `
"runmode": dev "runmode": dev
"autorender": false "autorender": false
"copyrequestbody": true "copyrequestbody": true
"PATH": GOPATH
"path1": ${GOPATH}
"path2": ${GOPATH||/home/go}
"empty": ""
` `
func TestYaml(t *testing.T) { keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"PATH": "GOPATH",
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
f, err := os.Create("testyaml.conf") f, err := os.Create("testyaml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -47,32 +69,42 @@ func TestYaml(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if yamlconf.String("appname") != "beeapi" { if yamlconf.String("appname") != "beeapi" {
t.Fatal("appname not equal to beeapi") t.Fatal("appname not equal to beeapi")
} }
if port, err := yamlconf.Int("httpport"); err != nil || port != 8080 {
t.Error(port) for k, v := range keyValue {
t.Fatal(err)
} var (
if port, err := yamlconf.Int64("mysqlport"); err != nil || port != 3600 { value interface{}
t.Error(port) err error
t.Fatal(err) )
}
if pi, err := yamlconf.Float("PI"); err != nil || pi != 3.1415976 { switch v.(type) {
t.Error(pi) case int:
t.Fatal(err) value, err = yamlconf.Int(k)
} case int64:
if yamlconf.String("runmode") != "dev" { value, err = yamlconf.Int64(k)
t.Fatal("runmode not equal to dev") case float64:
} value, err = yamlconf.Float(k)
if v, err := yamlconf.Bool("autorender"); err != nil || v != false { case bool:
t.Error(v) value, err = yamlconf.Bool(k)
t.Fatal(err) case []string:
} value = yamlconf.Strings(k)
if v, err := yamlconf.Bool("copyrequestbody"); err != nil || v != true { case string:
t.Error(v) value = yamlconf.String(k)
t.Fatal(err) default:
value, err = yamlconf.DIY(k)
}
if err != nil {
t.Errorf("get key %q value fatal,%v err %s", k, v, err)
} else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
}
} }
if err = yamlconf.Set("name", "astaxie"); err != nil { if err = yamlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -80,7 +112,4 @@ func TestYaml(t *testing.T) {
t.Fatal("get name error") t.Fatal("get name error")
} }
if yamlconf.Strings("emptystrings") != nil {
t.Fatal("get emtpy strings error")
}
} }

View File

@ -15,7 +15,11 @@
package beego package beego
import ( import (
"encoding/json"
"reflect"
"testing" "testing"
"github.com/astaxie/beego/config"
) )
func TestDefaults(t *testing.T) { func TestDefaults(t *testing.T) {
@ -27,3 +31,109 @@ func TestDefaults(t *testing.T) {
t.Errorf("FlashName was not set to default.") t.Errorf("FlashName was not set to default.")
} }
} }
func TestAssignConfig_01(t *testing.T) {
_BConfig := &Config{}
_BConfig.AppName = "beego_test"
jcf := &config.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`))
assignSingleConfig(_BConfig, ac)
if _BConfig.AppName != "beego_json" {
t.Log(_BConfig)
t.FailNow()
}
}
func TestAssignConfig_02(t *testing.T) {
_BConfig := &Config{}
bs, _ := json.Marshal(newBConfig())
jsonMap := map[string]interface{}{}
json.Unmarshal(bs, &jsonMap)
configMap := map[string]interface{}{}
for k, v := range jsonMap {
if reflect.TypeOf(v).Kind() == reflect.Map {
for k1, v1 := range v.(map[string]interface{}) {
if reflect.TypeOf(v1).Kind() == reflect.Map {
for k2, v2 := range v1.(map[string]interface{}) {
configMap[k2] = v2
}
} else {
configMap[k1] = v1
}
}
} else {
configMap[k] = v
}
}
configMap["MaxMemory"] = 1024
configMap["Graceful"] = true
configMap["XSRFExpire"] = 32
configMap["SessionProviderConfig"] = "file"
configMap["FileLineNum"] = true
jcf := &config.JSONConfig{}
bs, _ = json.Marshal(configMap)
ac, _ := jcf.ParseData([]byte(bs))
for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
if _BConfig.MaxMemory != 1024 {
t.Log(_BConfig.MaxMemory)
t.FailNow()
}
if !_BConfig.Listen.Graceful {
t.Log(_BConfig.Listen.Graceful)
t.FailNow()
}
if _BConfig.WebConfig.XSRFExpire != 32 {
t.Log(_BConfig.WebConfig.XSRFExpire)
t.FailNow()
}
if _BConfig.WebConfig.Session.SessionProviderConfig != "file" {
t.Log(_BConfig.WebConfig.Session.SessionProviderConfig)
t.FailNow()
}
if !_BConfig.Log.FileLineNum {
t.Log(_BConfig.Log.FileLineNum)
t.FailNow()
}
}
func TestAssignConfig_03(t *testing.T) {
jcf := &config.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`))
ac.Set("AppName", "test_app")
ac.Set("RunMode", "online")
ac.Set("StaticDir", "download:down download2:down2")
ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png")
assignConfig(ac)
t.Logf("%#v",BConfig)
if BConfig.AppName != "test_app" {
t.FailNow()
}
if BConfig.RunMode != "online" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download"] != "down" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download2"] != "down2" {
t.FailNow()
}
if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 {
t.FailNow()
}
}

View File

@ -27,6 +27,33 @@ import (
"sync" "sync"
) )
var (
//Default size==20B same as nginx
defaultGzipMinLength = 20
//Content will only be compressed if content length is either unknown or greater than gzipMinLength.
gzipMinLength = defaultGzipMinLength
//The compression level used for deflate compression. (0-9).
gzipCompressLevel int
//List of HTTP methods to compress. If not set, only GET requests are compressed.
includedMethods map[string]bool
getMethodOnly bool
)
func InitGzip(minLength, compressLevel int, methods []string) {
if minLength >= 0 {
gzipMinLength = minLength
}
gzipCompressLevel = compressLevel
if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression {
gzipCompressLevel = flate.BestSpeed
}
getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET")
includedMethods = make(map[string]bool, len(methods))
for _, v := range methods {
includedMethods[strings.ToUpper(v)] = true
}
}
type resetWriter interface { type resetWriter interface {
io.Writer io.Writer
Reset(w io.Writer) Reset(w io.Writer)
@ -41,20 +68,20 @@ func (n nopResetWriter) Reset(w io.Writer) {
} }
type acceptEncoder struct { type acceptEncoder struct {
name string name string
levelEncode func(int) resetWriter levelEncode func(int) resetWriter
bestSpeedPool *sync.Pool customCompressLevelPool *sync.Pool
bestCompressionPool *sync.Pool bestCompressionPool *sync.Pool
} }
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr} return nopResetWriter{wr}
} }
var rwr resetWriter var rwr resetWriter
switch level { switch level {
case flate.BestSpeed: case flate.BestSpeed:
rwr = ac.bestSpeedPool.Get().(resetWriter) rwr = ac.customCompressLevelPool.Get().(resetWriter)
case flate.BestCompression: case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter) rwr = ac.bestCompressionPool.Get().(resetWriter)
default: default:
@ -65,13 +92,18 @@ func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
} }
func (ac acceptEncoder) put(wr resetWriter, level int) { func (ac acceptEncoder) put(wr resetWriter, level int) {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return return
} }
wr.Reset(nil) wr.Reset(nil)
//notice
//compressionLevel==BestCompression DOES NOT MATTER
//sync.Pool will not memory leak
switch level { switch level {
case flate.BestSpeed: case gzipCompressLevel:
ac.bestSpeedPool.Put(wr) ac.customCompressLevelPool.Put(wr)
case flate.BestCompression: case flate.BestCompression:
ac.bestCompressionPool.Put(wr) ac.bestCompressionPool.Put(wr)
} }
@ -79,28 +111,22 @@ func (ac acceptEncoder) put(wr resetWriter, level int) {
var ( var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil} noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
gzipCompressEncoder = acceptEncoder{"gzip", gzipCompressEncoder = acceptEncoder{
func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, name: "gzip",
&sync.Pool{ levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr }, customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }},
}, bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }},
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
},
} }
//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
//deflate //deflate
//The "zlib" format defined in RFC 1950 [31] in combination with //The "zlib" format defined in RFC 1950 [31] in combination with
//the "deflate" compression mechanism described in RFC 1951 [29]. //the "deflate" compression mechanism described in RFC 1951 [29].
deflateCompressEncoder = acceptEncoder{"deflate", deflateCompressEncoder = acceptEncoder{
func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, name: "deflate",
&sync.Pool{ levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr }, customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }},
}, bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }},
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
},
} }
) )
@ -120,7 +146,11 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string,
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) // WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) if encoding == "" || len(content) < gzipMinLength {
_, err := writer.Write(content)
return false, "", err
}
return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel)
} }
// writeLevel reads from reader,writes to writer by specific encoding and compress level // writeLevel reads from reader,writes to writer by specific encoding and compress level
@ -156,7 +186,10 @@ func ParseEncoding(r *http.Request) string {
if r == nil { if r == nil {
return "" return ""
} }
return parseEncoding(r) if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] {
return parseEncoding(r)
}
return ""
} }
type q struct { type q struct {

View File

@ -24,13 +24,11 @@ package context
import ( import (
"bufio" "bufio"
"bytes"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -67,18 +65,18 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.ResponseWriter.reset(rw) ctx.ResponseWriter.reset(rw)
ctx.Input.Reset(ctx) ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx) ctx.Output.Reset(ctx)
ctx._xsrfToken = ""
} }
// Redirect does redirection to localurl with http header status code. // Redirect does redirection to localurl with http header status code.
// It sends http response header directly.
func (ctx *Context) Redirect(status int, localurl string) { func (ctx *Context) Redirect(status int, localurl string) {
ctx.Output.Header("Location", localurl) http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status)
ctx.ResponseWriter.WriteHeader(status)
} }
// Abort stops this request. // Abort stops this request.
// if beego.ErrorMaps exists, panic body. // if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) { func (ctx *Context) Abort(status int, body string) {
ctx.Output.SetStatus(status)
panic(body) panic(body)
} }
@ -195,14 +193,6 @@ func (r *Response) Write(p []byte) (int, error) {
return r.ResponseWriter.Write(p) return r.ResponseWriter.Write(p)
} }
// Copy writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (r *Response) Copy(buf *bytes.Buffer) (int64, error) {
r.Started = true
return io.Copy(r.ResponseWriter, buf)
}
// WriteHeader sends an HTTP response header with status code, // WriteHeader sends an HTTP response header with status code,
// and sets `started` to true. // and sets `started` to true.
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {

47
context/context_test.go Normal file
View File

@ -0,0 +1,47 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestXsrfReset_01(t *testing.T) {
r := &http.Request{}
c := NewContext()
c.Request = r
c.ResponseWriter = &Response{}
c.ResponseWriter.reset(httptest.NewRecorder())
c.Output.Reset(c)
c.Input.Reset(c)
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
token := c._xsrfToken
c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r)
if c._xsrfToken != "" {
t.FailNow()
}
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
if token == c._xsrfToken {
t.FailNow()
}
}

View File

@ -89,6 +89,9 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https". // Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string { func (input *BeegoInput) Scheme() string {
if scheme := input.Header("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if input.Context.Request.URL.Scheme != "" { if input.Context.Request.URL.Scheme != "" {
return input.Context.Request.URL.Scheme return input.Context.Request.URL.Scheme
} }

View File

@ -100,7 +100,7 @@ func TestSubDomain(t *testing.T) {
/* TODO Fix this /* TODO Fix this
r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" { if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
} }

View File

@ -21,8 +21,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
"io"
"mime" "mime"
"net/http" "net/http"
"net/url"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -72,10 +75,11 @@ func (output *BeegoOutput) Body(content []byte) error {
if output.Status != 0 { if output.Status != 0 {
output.Context.ResponseWriter.WriteHeader(output.Status) output.Context.ResponseWriter.WriteHeader(output.Status)
output.Status = 0 output.Status = 0
} else {
output.Context.ResponseWriter.Started = true
} }
io.Copy(output.Context.ResponseWriter, buf)
_, err := output.Context.ResponseWriter.Copy(buf) return nil
return err
} }
// Cookie sets cookie value via given key. // Cookie sets cookie value via given key.
@ -235,13 +239,21 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
// Download forces response for download file. // Download forces response for download file.
// it prepares the download response header automatically. // it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) { func (output *BeegoOutput) Download(file string, filename ...string) {
// check get file error, file not found or other error.
if _, err := os.Stat(file); err != nil {
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
return
}
var fName string
if len(filename) > 0 && filename[0] != "" {
fName = filename[0]
} else {
fName = filepath.Base(file)
}
output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName))
output.Header("Content-Description", "File Transfer") output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream") output.Header("Content-Type", "application/octet-stream")
if len(filename) > 0 && filename[0] != "" {
output.Header("Content-Disposition", "attachment; filename="+filename[0])
} else {
output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file))
}
output.Header("Content-Transfer-Encoding", "binary") output.Header("Content-Transfer-Encoding", "binary")
output.Header("Expires", "0") output.Header("Expires", "0")
output.Header("Cache-Control", "must-revalidate") output.Header("Cache-Control", "must-revalidate")
@ -269,55 +281,55 @@ func (output *BeegoOutput) SetStatus(status int) {
// IsCachable returns boolean of this request is cached. // IsCachable returns boolean of this request is cached.
// HTTP 304 means cached. // HTTP 304 means cached.
func (output *BeegoOutput) IsCachable(status int) bool { func (output *BeegoOutput) IsCachable() bool {
return output.Status >= 200 && output.Status < 300 || output.Status == 304 return output.Status >= 200 && output.Status < 300 || output.Status == 304
} }
// IsEmpty returns boolean of this request is empty. // IsEmpty returns boolean of this request is empty.
// HTTP 201204 and 304 means empty. // HTTP 201204 and 304 means empty.
func (output *BeegoOutput) IsEmpty(status int) bool { func (output *BeegoOutput) IsEmpty() bool {
return output.Status == 201 || output.Status == 204 || output.Status == 304 return output.Status == 201 || output.Status == 204 || output.Status == 304
} }
// IsOk returns boolean of this request runs well. // IsOk returns boolean of this request runs well.
// HTTP 200 means ok. // HTTP 200 means ok.
func (output *BeegoOutput) IsOk(status int) bool { func (output *BeegoOutput) IsOk() bool {
return output.Status == 200 return output.Status == 200
} }
// IsSuccessful returns boolean of this request runs successfully. // IsSuccessful returns boolean of this request runs successfully.
// HTTP 2xx means ok. // HTTP 2xx means ok.
func (output *BeegoOutput) IsSuccessful(status int) bool { func (output *BeegoOutput) IsSuccessful() bool {
return output.Status >= 200 && output.Status < 300 return output.Status >= 200 && output.Status < 300
} }
// IsRedirect returns boolean of this request is redirection header. // IsRedirect returns boolean of this request is redirection header.
// HTTP 301,302,307 means redirection. // HTTP 301,302,307 means redirection.
func (output *BeegoOutput) IsRedirect(status int) bool { func (output *BeegoOutput) IsRedirect() bool {
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
} }
// IsForbidden returns boolean of this request is forbidden. // IsForbidden returns boolean of this request is forbidden.
// HTTP 403 means forbidden. // HTTP 403 means forbidden.
func (output *BeegoOutput) IsForbidden(status int) bool { func (output *BeegoOutput) IsForbidden() bool {
return output.Status == 403 return output.Status == 403
} }
// IsNotFound returns boolean of this request is not found. // IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden. // HTTP 404 means forbidden.
func (output *BeegoOutput) IsNotFound(status int) bool { func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404 return output.Status == 404
} }
// IsClientError returns boolean of this request client sends error data. // IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden. // HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool { func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500 return output.Status >= 400 && output.Status < 500
} }
// IsServerError returns boolean of this server handler errors. // IsServerError returns boolean of this server handler errors.
// HTTP 5xx means server internal error. // HTTP 5xx means server internal error.
func (output *BeegoOutput) IsServerError(status int) bool { func (output *BeegoOutput) IsServerError() bool {
return output.Status >= 500 && output.Status < 600 return output.Status >= 500 && output.Status < 600
} }

View File

@ -208,7 +208,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue continue
} }
buf.Reset() buf.Reset()
err = executeTemplate(&buf, sectionTpl, c.Data) err = ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,7 +217,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
} }
buf.Reset() buf.Reset()
executeTemplate(&buf, c.Layout, c.Data) ExecuteTemplate(&buf, c.Layout, c.Data)
} }
return buf.Bytes(), err return buf.Bytes(), err
} }
@ -242,7 +242,7 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
} }
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
} }
return buf, executeTemplate(&buf, c.TplName, c.Data) return buf, ExecuteTemplate(&buf, c.TplName, c.Data)
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
@ -261,12 +261,13 @@ func (c *Controller) Abort(code string) {
// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body.
func (c *Controller) CustomAbort(status int, body string) { func (c *Controller) CustomAbort(status int, body string) {
c.Ctx.Output.Status = status // first panic from ErrorMaps, it is user defined error functions.
// first panic from ErrorMaps, is is user defined error functions.
if _, ok := ErrorMaps[body]; ok { if _, ok := ErrorMaps[body]; ok {
c.Ctx.Output.Status = status
panic(body) panic(body)
} }
// last panic user string // last panic user string
c.Ctx.ResponseWriter.WriteHeader(status)
c.Ctx.ResponseWriter.Write([]byte(body)) c.Ctx.ResponseWriter.Write([]byte(body))
panic(ErrAbort) panic(ErrAbort)
} }

213
error.go
View File

@ -210,159 +210,139 @@ var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error. // show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) { func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 401,
"Title": http.StatusText(401), "<br>The page you have requested can't be authorized."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." + "<br>The credentials you supplied are incorrect"+
"<br>Perhaps you are here because:" + "<br>There are errors in the website address"+
"<br><br><ul>" + "</ul>",
"<br>The credentials you supplied are incorrect" + )
"<br>There are errors in the website address" +
"</ul>")
t.Execute(rw, data)
} }
// show 402 Payment Required // show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) { func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 402,
"Title": http.StatusText(402), "<br>The page you have requested Payment Required."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested Payment Required." + "<br>The credentials you supplied are incorrect"+
"<br>Perhaps you are here because:" + "<br>There are errors in the website address"+
"<br><br><ul>" + "</ul>",
"<br>The credentials you supplied are incorrect" + )
"<br>There are errors in the website address" +
"</ul>")
t.Execute(rw, data)
} }
// show 403 forbidden error. // show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) { func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 403,
"Title": http.StatusText(403), "<br>The page you have requested is forbidden."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested is forbidden." + "<br>Your address may be blocked"+
"<br>Perhaps you are here because:" + "<br>The site may be disabled"+
"<br><br><ul>" + "<br>You need to log in"+
"<br>Your address may be blocked" + "</ul>",
"<br>The site may be disabled" + )
"<br>You need to log in" +
"</ul>")
t.Execute(rw, data)
} }
// show 404 notfound error. // show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) { func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 404,
"Title": http.StatusText(404), "<br>The page you have requested has flown the coop."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." + "<br>The page has moved"+
"<br>Perhaps you are here because:" + "<br>The page no longer exists"+
"<br><br><ul>" + "<br>You were looking for your puppy and got lost"+
"<br>The page has moved" + "<br>You like 404 pages"+
"<br>The page no longer exists" + "</ul>",
"<br>You were looking for your puppy and got lost" + )
"<br>You like 404 pages" +
"</ul>")
t.Execute(rw, data)
} }
// show 405 Method Not Allowed // show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 405,
"Title": http.StatusText(405), "<br>The method you have requested Not Allowed."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The method you have requested Not Allowed." + "<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+
"<br>Perhaps you are here because:" + "<br>The response MUST include an Allow header containing a list of valid methods for the requested resource."+
"<br><br><ul>" + "</ul>",
"<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" + )
"<br>The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"</ul>")
t.Execute(rw, data)
} }
// show 500 internal server error. // show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) { func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 500,
"Title": http.StatusText(500), "<br>The page you have requested is down right now."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>Please try again later and report the error to the website administrator"+
data["Content"] = template.HTML("<br>The page you have requested is down right now." + "<br></ul>",
"<br><br><ul>" + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 501 Not Implemented. // show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) { func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 501,
"Title": http.StatusText(504), "<br>The page you have requested is Not Implemented."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>Please try again later and report the error to the website administrator"+
data["Content"] = template.HTML("<br>The page you have requested is Not Implemented." + "<br></ul>",
"<br><br><ul>" + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 502 Bad Gateway. // show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) { func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 502,
"Title": http.StatusText(502), "<br>The page you have requested is down right now."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+
data["Content"] = template.HTML("<br>The page you have requested is down right now." + "<br>Please try again later and report the error to the website administrator"+
"<br><br><ul>" + "<br></ul>",
"<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 503 service unavailable error. // show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 503,
"Title": http.StatusText(503), "<br>The page you have requested is unavailable."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested is unavailable." + "<br><br>The page is overloaded"+
"<br>Perhaps you are here because:" + "<br>Please try again later."+
"<br><br><ul>" + "</ul>",
"<br><br>The page is overloaded" + )
"<br>Please try again later." +
"</ul>")
t.Execute(rw, data)
} }
// show 504 Gateway Timeout. // show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
504,
"<br>The page you have requested is unavailable"+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+
"<br>Please try again later."+
"</ul>",
)
}
func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := map[string]interface{}{ data := map[string]interface{}{
"Title": http.StatusText(504), "Title": http.StatusText(errCode),
"BeegoVersion": VERSION, "BeegoVersion": VERSION,
"Content": template.HTML(errContent),
} }
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
"<br>Please try again later." +
"</ul>")
t.Execute(rw, data) t.Execute(rw, data)
} }
@ -408,7 +388,10 @@ func exception(errCode string, ctx *context.Context) {
if err == nil { if err == nil {
return v return v
} }
return 503 if ctx.Output.Status == 0 {
return 503
}
return ctx.Output.Status
} }
for _, ec := range []string{errCode, "503", "500"} { for _, ec := range []string{errCode, "503", "500"} {

88
error_test.go Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
)
type errorTestController struct {
Controller
}
const parseCodeError = "parse code error"
func (ec *errorTestController) Get() {
errorCode, err := ec.GetInt("code")
if err != nil {
ec.Abort(parseCodeError)
}
if errorCode != 0 {
ec.CustomAbort(errorCode, ec.GetString("code"))
}
ec.Abort("404")
}
func TestErrorCode_01(t *testing.T) {
registerDefaultErrorHandler()
for k := range ErrorMaps {
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
code, _ := strconv.Atoi(k)
if w.Code != code {
t.Fail()
}
if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) {
t.Fail()
}
}
}
func TestErrorCode_02(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=0", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 404 {
t.Fail()
}
}
func TestErrorCode_03(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=panic", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 200 {
t.Fail()
}
if string(w.Body.Bytes()) != parseCodeError {
t.Fail()
}
}

View File

@ -20,14 +20,8 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
func init() {
BeeLogger = logs.NewLogger(10000)
BeeLogger.SetLogger("console", "")
}
var FilterUser = func(ctx *context.Context) { var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
} }

View File

@ -6,6 +6,8 @@ import (
"net/http" "net/http"
"path/filepath" "path/filepath"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
@ -45,13 +47,16 @@ func registerSession() error {
sessionConfig := AppConfig.String("sessionConfig") sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" { if sessionConfig == "" {
conf := map[string]interface{}{ conf := map[string]interface{}{
"cookieName": BConfig.WebConfig.Session.SessionName, "cookieName": BConfig.WebConfig.Session.SessionName,
"gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime, "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
"providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig), "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
"secure": BConfig.Listen.EnableHTTPS, "secure": BConfig.Listen.EnableHTTPS,
"enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie, "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
"domain": BConfig.WebConfig.Session.SessionDomain, "domain": BConfig.WebConfig.Session.SessionDomain,
"cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime, "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
"enableSidInHttpHeader": BConfig.WebConfig.Session.EnableSidInHttpHeader,
"sessionNameInHttpHeader": BConfig.WebConfig.Session.SessionNameInHttpHeader,
"enableSidInUrlQuery": BConfig.WebConfig.Session.EnableSidInUrlQuery,
} }
confBytes, err := json.Marshal(conf) confBytes, err := json.Marshal(conf)
if err != nil { if err != nil {
@ -70,24 +75,27 @@ func registerSession() error {
func registerTemplate() error { func registerTemplate() error {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn(err) logs.Warn(err)
} }
return err return err
} }
return nil return nil
} }
func registerDocs() error {
if BConfig.WebConfig.EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
return nil
}
func registerAdmin() error { func registerAdmin() error {
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
go beeAdminApp.Run() go beeAdminApp.Run()
} }
return nil return nil
} }
func registerGzip() error {
if BConfig.EnableGzip {
context.InitGzip(
AppConfig.DefaultInt("gzipMinLength", -1),
AppConfig.DefaultInt("gzipCompressLevel", -1),
AppConfig.DefaultStrings("includedMethods", []string{"GET"}),
)
}
return nil
}

35
log.go
View File

@ -33,82 +33,77 @@ const (
) )
// BeeLogger references the used application logger. // BeeLogger references the used application logger.
var BeeLogger = logs.NewLogger(100) var BeeLogger = logs.GetBeeLogger()
// SetLevel sets the global log level used by the simple logger. // SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) { func SetLevel(l int) {
BeeLogger.SetLevel(l) logs.SetLevel(l)
} }
// SetLogFuncCall set the CallDepth, default is 3 // SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) { func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b) logs.SetLogFuncCall(b)
BeeLogger.SetLogFuncCallDepth(3)
} }
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error { func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config) return logs.SetLogger(adaptername, config)
if err != nil {
return err
}
return nil
} }
// Emergency logs a message at emergency level. // Emergency logs a message at emergency level.
func Emergency(v ...interface{}) { func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...) logs.Emergency(generateFmtStr(len(v)), v...)
} }
// Alert logs a message at alert level. // Alert logs a message at alert level.
func Alert(v ...interface{}) { func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...) logs.Alert(generateFmtStr(len(v)), v...)
} }
// Critical logs a message at critical level. // Critical logs a message at critical level.
func Critical(v ...interface{}) { func Critical(v ...interface{}) {
BeeLogger.Critical(generateFmtStr(len(v)), v...) logs.Critical(generateFmtStr(len(v)), v...)
} }
// Error logs a message at error level. // Error logs a message at error level.
func Error(v ...interface{}) { func Error(v ...interface{}) {
BeeLogger.Error(generateFmtStr(len(v)), v...) logs.Error(generateFmtStr(len(v)), v...)
} }
// Warning logs a message at warning level. // Warning logs a message at warning level.
func Warning(v ...interface{}) { func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...) logs.Warning(generateFmtStr(len(v)), v...)
} }
// Warn compatibility alias for Warning() // Warn compatibility alias for Warning()
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
BeeLogger.Warn(generateFmtStr(len(v)), v...) logs.Warn(generateFmtStr(len(v)), v...)
} }
// Notice logs a message at notice level. // Notice logs a message at notice level.
func Notice(v ...interface{}) { func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...) logs.Notice(generateFmtStr(len(v)), v...)
} }
// Informational logs a message at info level. // Informational logs a message at info level.
func Informational(v ...interface{}) { func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...) logs.Informational(generateFmtStr(len(v)), v...)
} }
// Info compatibility alias for Warning() // Info compatibility alias for Warning()
func Info(v ...interface{}) { func Info(v ...interface{}) {
BeeLogger.Info(generateFmtStr(len(v)), v...) logs.Info(generateFmtStr(len(v)), v...)
} }
// Debug logs a message at debug level. // Debug logs a message at debug level.
func Debug(v ...interface{}) { func Debug(v ...interface{}) {
BeeLogger.Debug(generateFmtStr(len(v)), v...) logs.Debug(generateFmtStr(len(v)), v...)
} }
// Trace logs a message at trace level. // Trace logs a message at trace level.
// compatibility alias for Warning() // compatibility alias for Warning()
func Trace(v ...interface{}) { func Trace(v ...interface{}) {
BeeLogger.Trace(generateFmtStr(len(v)), v...) logs.Trace(generateFmtStr(len(v)), v...)
} }
func generateFmtStr(n int) string { func generateFmtStr(n int) string {

View File

@ -113,5 +113,5 @@ func (c *connWriter) needToConnectOnMsg() bool {
} }
func init() { func init() {
Register("conn", NewConn) Register(AdapterConn, NewConn)
} }

View File

@ -97,5 +97,5 @@ func (c *consoleWriter) Flush() {
} }
func init() { func init() {
Register("console", NewConsole) Register(AdapterConsole, NewConsole)
} }

View File

@ -76,5 +76,5 @@ func (el *esLogger) Flush() {
} }
func init() { func init() {
logs.Register("es", NewES) logs.Register(logs.AdapterEs, NewES)
} }

View File

@ -22,6 +22,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -30,7 +31,7 @@ import (
// fileLogWriter implements LoggerInterface. // fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency. // It writes messages by lines limit, file size limit, or time frequency.
type fileLogWriter struct { type fileLogWriter struct {
sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file // The opened file
Filename string `json:"filename"` Filename string `json:"filename"`
fileWriter *os.File fileWriter *os.File
@ -47,12 +48,13 @@ type fileLogWriter struct {
Daily bool `json:"daily"` Daily bool `json:"daily"`
MaxDays int64 `json:"maxdays"` MaxDays int64 `json:"maxdays"`
dailyOpenDate int dailyOpenDate int
dailyOpenTime time.Time
Rotate bool `json:"rotate"` Rotate bool `json:"rotate"`
Level int `json:"level"` Level int `json:"level"`
Perm os.FileMode `json:"perm"` Perm string `json:"perm"`
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
} }
@ -60,14 +62,11 @@ type fileLogWriter struct {
// newFileWriter create a FileLogWriter returning as LoggerInterface. // newFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger { func newFileWriter() Logger {
w := &fileLogWriter{ w := &fileLogWriter{
Filename: "", Daily: true,
MaxLines: 1000000, MaxDays: 7,
MaxSize: 1 << 28, //256 MB Rotate: true,
Daily: true, Level: LevelTrace,
MaxDays: 7, Perm: "0660",
Rotate: true,
Level: LevelTrace,
Perm: 0660,
} }
return w return w
} }
@ -77,11 +76,11 @@ func newFileWriter() Logger {
// { // {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxLines":10000, // "maxLines":10000,
// "maxsize":1<<30, // "maxsize":1024,
// "daily":true, // "daily":true,
// "maxDays":15, // "maxDays":15,
// "rotate":true, // "rotate":true,
// "perm":0600 // "perm":"0600"
// } // }
func (w *fileLogWriter) Init(jsonConfig string) error { func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonConfig), w) err := json.Unmarshal([]byte(jsonConfig), w)
@ -128,7 +127,9 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
h, d := formatTimeHeader(when) h, d := formatTimeHeader(when)
msg = string(h) + msg + "\n" msg = string(h) + msg + "\n"
if w.Rotate { if w.Rotate {
w.RLock()
if w.needRotate(len(msg), d) { if w.needRotate(len(msg), d) {
w.RUnlock()
w.Lock() w.Lock()
if w.needRotate(len(msg), d) { if w.needRotate(len(msg), d) {
if err := w.doRotate(when); err != nil { if err := w.doRotate(when); err != nil {
@ -136,6 +137,8 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
} }
} }
w.Unlock() w.Unlock()
} else {
w.RUnlock()
} }
} }
@ -151,7 +154,11 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
func (w *fileLogWriter) createLogFile() (*os.File, error) { func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file // Open the log file
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm) perm, err := strconv.ParseInt(w.Perm, 8, 64)
if err != nil {
return nil, err
}
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
return fd, err return fd, err
} }
@ -162,8 +169,12 @@ func (w *fileLogWriter) initFd() error {
return fmt.Errorf("get stat err: %s\n", err) return fmt.Errorf("get stat err: %s\n", err)
} }
w.maxSizeCurSize = int(fInfo.Size()) w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenDate = time.Now().Day() w.dailyOpenTime = time.Now()
w.dailyOpenDate = w.dailyOpenTime.Day()
w.maxLinesCurLines = 0 w.maxLinesCurLines = 0
if w.Daily {
go w.dailyRotate(w.dailyOpenTime)
}
if fInfo.Size() > 0 { if fInfo.Size() > 0 {
count, err := w.lines() count, err := w.lines()
if err != nil { if err != nil {
@ -174,6 +185,22 @@ func (w *fileLogWriter) initFd() error {
return nil return nil
} }
func (w *fileLogWriter) dailyRotate(openTime time.Time) {
y, m, d := openTime.Add(24 * time.Hour).Date()
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
select {
case <-tm.C:
w.Lock()
if w.needRotate(0, time.Now().Day()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
}
func (w *fileLogWriter) lines() (int, error) { func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename) fd, err := os.Open(w.Filename)
if err != nil { if err != nil {
@ -204,22 +231,29 @@ func (w *fileLogWriter) lines() (int, error) {
// DoRotate means it need to write file in new file. // DoRotate means it need to write file in new file.
// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) // new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size)
func (w *fileLogWriter) doRotate(logTime time.Time) error { func (w *fileLogWriter) doRotate(logTime time.Time) error {
_, err := os.Lstat(w.Filename)
if err != nil {
return err
}
// file exists // file exists
// Find the next available number // Find the next available number
num := 1 num := 1
fName := "" fName := ""
_, err := os.Lstat(w.Filename)
if err != nil {
//even if the file is not exist or other ,we should RESTART the logger
goto RESTART_LOGGER
}
if w.MaxLines > 0 || w.MaxSize > 0 { if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= 999; num++ { for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix) fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
} }
} else { } else {
fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, logTime.Format("2006-01-02"), w.suffix) fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName)
}
} }
// return error if the last file checked still existed // return error if the last file checked still existed
if err == nil { if err == nil {
@ -231,16 +265,18 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name // Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger // even if occurs error,we MUST guarantee to restart new logger
renameErr := os.Rename(w.Filename, fName) err = os.Rename(w.Filename, fName)
// re-start logger // re-start logger
RESTART_LOGGER:
startLoggerErr := w.startLogger() startLoggerErr := w.startLogger()
go w.deleteOldLog() go w.deleteOldLog()
if startLoggerErr != nil { if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
} }
if renameErr != nil { if err != nil {
return fmt.Errorf("Rotate: %s\n", renameErr) return fmt.Errorf("Rotate: %s\n", err)
} }
return nil return nil
@ -255,8 +291,12 @@ func (w *fileLogWriter) deleteOldLog() {
} }
}() }()
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) { if info == nil {
if strings.HasPrefix(filepath.Base(path), w.fileNameOnly) && return
}
if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) { strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path) os.Remove(path)
} }
@ -278,5 +318,5 @@ func (w *fileLogWriter) Flush() {
} }
func init() { func init() {
Register("file", newFileWriter) Register(AdapterFile, newFileWriter)
} }

View File

@ -17,12 +17,34 @@ package logs
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strconv" "strconv"
"testing" "testing"
"time" "time"
) )
func TestFilePerm(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log", "perm": "0600"}`)
log.Debug("debug")
log.Informational("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
file, err := os.Stat("test.log")
if err != nil {
t.Fatal(err)
}
if file.Mode() != 0600 {
t.Fatal("unexpected log file permission")
}
os.Remove("test.log")
}
func TestFile1(t *testing.T) { func TestFile1(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`) log.SetLogger("file", `{"filename":"test.log"}`)
@ -89,7 +111,7 @@ func TestFile2(t *testing.T) {
os.Remove("test2.log") os.Remove("test2.log")
} }
func TestFileRotate(t *testing.T) { func TestFileRotate_01(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug") log.Debug("debug")
@ -110,6 +132,90 @@ func TestFileRotate(t *testing.T) {
os.Remove("test3.log") os.Remove("test3.log")
} }
func TestFileRotate_02(t *testing.T) {
fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
testFileRotate(t, fn1, fn2)
}
func TestFileRotate_03(t *testing.T) {
fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2)
os.Remove(fn)
}
func TestFileRotate_04(t *testing.T) {
fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
testFileDailyRotate(t, fn1, fn2)
}
func TestFileRotate_05(t *testing.T) {
fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileDailyRotate(t, fn1, fn2)
os.Remove(fn)
}
func testFileRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Daily: true,
MaxDays: 7,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Daily: true,
MaxDays: 7,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location())
today = today.Add(-1 * time.Second)
fw.dailyRotate(today)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
content, err := ioutil.ReadFile(file)
if err != nil {
t.FailNow()
}
if len(content) > 0 {
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func exists(path string) (bool, error) { func exists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
if err == nil { if err == nil {

View File

@ -35,10 +35,12 @@ package logs
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"path" "path"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
) )
@ -55,16 +57,28 @@ const (
LevelDebug LevelDebug
) )
// Legacy loglevel constants to ensure backwards compatibility. // levelLogLogger is defined to implement log.Logger
// // the real log level will be LevelEmergency
// Deprecated: will be removed in 1.5.0. const levelLoggerImpl = -1
// Name for adapter with beego official support
const (
AdapterConsole = "console"
AdapterFile = "file"
AdapterMultiFile = "multifile"
AdapterMail = "stmp"
AdapterConn = "conn"
AdapterEs = "es"
)
// Legacy log level constants to ensure backwards compatibility.
const ( const (
LevelInfo = LevelInformational LevelInfo = LevelInformational
LevelTrace = LevelDebug LevelTrace = LevelDebug
LevelWarn = LevelWarning LevelWarn = LevelWarning
) )
type loggerType func() Logger type newLoggerFunc func() Logger
// Logger defines the behavior of a log provider. // Logger defines the behavior of a log provider.
type Logger interface { type Logger interface {
@ -74,12 +88,13 @@ type Logger interface {
Flush() Flush()
} }
var adapters = make(map[string]loggerType) var adapters = make(map[string]newLoggerFunc)
var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "}
// Register makes a log provide available by the provided name. // Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
// it panics. // it panics.
func Register(name string, log loggerType) { func Register(name string, log newLoggerFunc) {
if log == nil { if log == nil {
panic("logs: Register provide is nil") panic("logs: Register provide is nil")
} }
@ -94,15 +109,19 @@ func Register(name string, log loggerType) {
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
init bool
enableFuncCallDepth bool enableFuncCallDepth bool
loggerFuncCallDepth int loggerFuncCallDepth int
asynchronous bool asynchronous bool
msgChanLen int64
msgChan chan *logMsg msgChan chan *logMsg
signalChan chan string signalChan chan string
wg sync.WaitGroup wg sync.WaitGroup
outputs []*nameLogger outputs []*nameLogger
} }
const defaultAsyncMsgLen = 1e3
type nameLogger struct { type nameLogger struct {
Logger Logger
name string name string
@ -119,18 +138,31 @@ var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger. // NewLogger returns a new BeeLogger.
// channelLen means the number of messages in chan(used where asynchronous is true). // channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way. // if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channelLen int64) *BeeLogger { func NewLogger(channelLens ...int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.level = LevelDebug bl.level = LevelDebug
bl.loggerFuncCallDepth = 2 bl.loggerFuncCallDepth = 2
bl.msgChan = make(chan *logMsg, channelLen) bl.msgChanLen = append(channelLens, 0)[0]
if bl.msgChanLen <= 0 {
bl.msgChanLen = defaultAsyncMsgLen
}
bl.signalChan = make(chan string, 1) bl.signalChan = make(chan string, 1)
bl.setLogger(AdapterConsole)
return bl return bl
} }
// Async set the log to asynchronous and start the goroutine // Async set the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async() *BeeLogger { func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger {
bl.lock.Lock()
defer bl.lock.Unlock()
if bl.asynchronous {
return bl
}
bl.asynchronous = true bl.asynchronous = true
if len(msgLen) > 0 && msgLen[0] > 0 {
bl.msgChanLen = msgLen[0]
}
bl.msgChan = make(chan *logMsg, bl.msgChanLen)
logMsgPool = &sync.Pool{ logMsgPool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return &logMsg{} return &logMsg{}
@ -143,10 +175,8 @@ func (bl *BeeLogger) Async() *BeeLogger {
// SetLogger provides a given logger adapter into BeeLogger with config string. // SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}. // config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adapterName string, config string) error { func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
bl.lock.Lock() config := append(configs, "{}")[0]
defer bl.lock.Unlock()
for _, l := range bl.outputs { for _, l := range bl.outputs {
if l.name == adapterName { if l.name == adapterName {
return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
@ -168,6 +198,18 @@ func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
return nil return nil
} }
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if !bl.init {
bl.outputs = []*nameLogger{}
bl.init = true
}
return bl.setLogger(adapterName, configs...)
}
// DelLogger remove a logger adapter in BeeLogger. // DelLogger remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock() bl.lock.Lock()
@ -196,7 +238,37 @@ func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) {
} }
} }
func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { func (bl *BeeLogger) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
// writeMsg will always add a '\n' character
if p[len(p)-1] == '\n' {
p = p[0 : len(p)-1]
}
// set levelLoggerImpl to ensure all log message will be write out
err = bl.writeMsg(levelLoggerImpl, string(p))
if err == nil {
return len(p), err
}
return 0, err
}
func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error {
if !bl.init {
bl.lock.Lock()
bl.setLogger(AdapterConsole)
bl.lock.Unlock()
}
if logLevel == levelLoggerImpl {
// set to emergency to ensure all log will be print out correctly
logLevel = LevelEmergency
} else {
msg = levelPrefix[logLevel] + msg
}
if len(v) > 0 {
msg = fmt.Sprintf(msg, v...)
}
when := time.Now() when := time.Now()
if bl.enableFuncCallDepth { if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
@ -205,7 +277,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
line = 0 line = 0
} }
_, filename := path.Split(file) _, filename := path.Split(file)
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
} }
if bl.asynchronous { if bl.asynchronous {
lm := logMsgPool.Get().(*logMsg) lm := logMsgPool.Get().(*logMsg)
@ -273,8 +345,7 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level { if LevelEmergency > bl.level {
return return
} }
msg := fmt.Sprintf("[M] "+format, v...) bl.writeMsg(LevelEmergency, format, v...)
bl.writeMsg(LevelEmergency, msg)
} }
// Alert Log ALERT level message. // Alert Log ALERT level message.
@ -282,8 +353,7 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level { if LevelAlert > bl.level {
return return
} }
msg := fmt.Sprintf("[A] "+format, v...) bl.writeMsg(LevelAlert, format, v...)
bl.writeMsg(LevelAlert, msg)
} }
// Critical Log CRITICAL level message. // Critical Log CRITICAL level message.
@ -291,8 +361,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level { if LevelCritical > bl.level {
return return
} }
msg := fmt.Sprintf("[C] "+format, v...) bl.writeMsg(LevelCritical, format, v...)
bl.writeMsg(LevelCritical, msg)
} }
// Error Log ERROR level message. // Error Log ERROR level message.
@ -300,17 +369,12 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level { if LevelError > bl.level {
return return
} }
msg := fmt.Sprintf("[E] "+format, v...) bl.writeMsg(LevelError, format, v...)
bl.writeMsg(LevelError, msg)
} }
// Warning Log WARNING level message. // Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) { func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarning > bl.level { bl.Warn(format, v...)
return
}
msg := fmt.Sprintf("[W] "+format, v...)
bl.writeMsg(LevelWarning, msg)
} }
// Notice Log NOTICE level message. // Notice Log NOTICE level message.
@ -318,17 +382,12 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level { if LevelNotice > bl.level {
return return
} }
msg := fmt.Sprintf("[N] "+format, v...) bl.writeMsg(LevelNotice, format, v...)
bl.writeMsg(LevelNotice, msg)
} }
// Informational Log INFORMATIONAL level message. // Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) { func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInformational > bl.level { bl.Info(format, v...)
return
}
msg := fmt.Sprintf("[I] "+format, v...)
bl.writeMsg(LevelInformational, msg)
} }
// Debug Log DEBUG level message. // Debug Log DEBUG level message.
@ -336,38 +395,31 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level { if LevelDebug > bl.level {
return return
} }
msg := fmt.Sprintf("[D] "+format, v...) bl.writeMsg(LevelDebug, format, v...)
bl.writeMsg(LevelDebug, msg)
} }
// Warn Log WARN level message. // Warn Log WARN level message.
// compatibility alias for Warning() // compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
if LevelWarning > bl.level { if LevelWarn > bl.level {
return return
} }
msg := fmt.Sprintf("[W] "+format, v...) bl.writeMsg(LevelWarn, format, v...)
bl.writeMsg(LevelWarning, msg)
} }
// Info Log INFO level message. // Info Log INFO level message.
// compatibility alias for Informational() // compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
if LevelInformational > bl.level { if LevelInfo > bl.level {
return return
} }
msg := fmt.Sprintf("[I] "+format, v...) bl.writeMsg(LevelInfo, format, v...)
bl.writeMsg(LevelInformational, msg)
} }
// Trace Log TRACE level message. // Trace Log TRACE level message.
// compatibility alias for Debug() // compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
if LevelDebug > bl.level { bl.Debug(format, v...)
return
}
msg := fmt.Sprintf("[D] "+format, v...)
bl.writeMsg(LevelDebug, msg)
} }
// Flush flush all chan data. // Flush flush all chan data.
@ -386,6 +438,7 @@ func (bl *BeeLogger) Close() {
if bl.asynchronous { if bl.asynchronous {
bl.signalChan <- "close" bl.signalChan <- "close"
bl.wg.Wait() bl.wg.Wait()
close(bl.msgChan)
} else { } else {
bl.flush() bl.flush()
for _, l := range bl.outputs { for _, l := range bl.outputs {
@ -393,7 +446,6 @@ func (bl *BeeLogger) Close() {
} }
bl.outputs = nil bl.outputs = nil
} }
close(bl.msgChan)
close(bl.signalChan) close(bl.signalChan)
} }
@ -407,16 +459,175 @@ func (bl *BeeLogger) Reset() {
} }
func (bl *BeeLogger) flush() { func (bl *BeeLogger) flush() {
for { if bl.asynchronous {
if len(bl.msgChan) > 0 { for {
bm := <-bl.msgChan if len(bl.msgChan) > 0 {
bl.writeToLoggers(bm.when, bm.msg, bm.level) bm := <-bl.msgChan
logMsgPool.Put(bm) bl.writeToLoggers(bm.when, bm.msg, bm.level)
continue logMsgPool.Put(bm)
continue
}
break
} }
break
} }
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// beeLogger references the used application logger.
var beeLogger *BeeLogger = NewLogger()
// GetLogger returns the default BeeLogger
func GetBeeLogger() *BeeLogger {
return beeLogger
}
var beeLoggerMap = struct {
sync.RWMutex
logs map[string]*log.Logger
}{
logs: map[string]*log.Logger{},
}
// GetLogger returns the default BeeLogger
func GetLogger(prefixes ...string) *log.Logger {
prefix := append(prefixes, "")[0]
if prefix != "" {
prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix))
}
beeLoggerMap.RLock()
l, ok := beeLoggerMap.logs[prefix]
if ok {
beeLoggerMap.RUnlock()
return l
}
beeLoggerMap.RUnlock()
beeLoggerMap.Lock()
defer beeLoggerMap.Unlock()
l, ok = beeLoggerMap.logs[prefix]
if !ok {
l = log.New(beeLogger, prefix, 0)
beeLoggerMap.logs[prefix] = l
}
return l
}
// Reset will remove all the adapter
func Reset() {
beeLogger.Reset()
}
func Async(msgLen ...int64) *BeeLogger {
return beeLogger.Async(msgLen...)
}
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
beeLogger.SetLevel(l)
}
// EnableFuncCallDepth enable log funcCallDepth
func EnableFuncCallDepth(b bool) {
beeLogger.enableFuncCallDepth = b
}
// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
beeLogger.EnableFuncCallDepth(b)
beeLogger.SetLogFuncCallDepth(3)
}
// SetLogFuncCallDepth set log funcCallDepth
func SetLogFuncCallDepth(d int) {
beeLogger.loggerFuncCallDepth = d
}
// SetLogger sets a new logger.
func SetLogger(adapter string, config ...string) error {
err := beeLogger.SetLogger(adapter, config...)
if err != nil {
return err
}
return nil
}
// Emergency logs a message at emergency level.
func Emergency(f interface{}, v ...interface{}) {
beeLogger.Emergency(formatLog(f, v...))
}
// Alert logs a message at alert level.
func Alert(f interface{}, v ...interface{}) {
beeLogger.Alert(formatLog(f, v...))
}
// Critical logs a message at critical level.
func Critical(f interface{}, v ...interface{}) {
beeLogger.Critical(formatLog(f, v...))
}
// Error logs a message at error level.
func Error(f interface{}, v ...interface{}) {
beeLogger.Error(formatLog(f, v...))
}
// Warning logs a message at warning level.
func Warning(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Warn compatibility alias for Warning()
func Warn(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Notice logs a message at notice level.
func Notice(f interface{}, v ...interface{}) {
beeLogger.Notice(formatLog(f, v...))
}
// Informational logs a message at info level.
func Informational(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Info compatibility alias for Warning()
func Info(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Debug logs a message at debug level.
func Debug(f interface{}, v ...interface{}) {
beeLogger.Debug(formatLog(f, v...))
}
// Trace logs a message at trace level.
// compatibility alias for Warning()
func Trace(f interface{}, v ...interface{}) {
beeLogger.Trace(formatLog(f, v...))
}
func formatLog(f interface{}, v ...interface{}) string {
var msg string
switch f.(type) {
case string:
msg = f.(string)
if len(v) == 0 {
return msg
}
if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") {
//format string
} else {
//do not contain format char
msg += strings.Repeat(" %v", len(v))
}
default:
msg = fmt.Sprint(f)
if len(v) == 0 {
return msg
}
msg += strings.Repeat(" %v", len(v))
}
return fmt.Sprintf(msg, v...)
}

View File

@ -36,43 +36,46 @@ func (lg *logWriter) println(when time.Time, msg string) {
lg.Unlock() lg.Unlock()
} }
const y1 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
const y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
const mo1 = `000000000111`
const mo2 = `123456789012`
const d1 = `0000000001111111111222222222233`
const d2 = `1234567890123456789012345678901`
const h1 = `000000000011111111112222`
const h2 = `012345678901234567890123`
const mi1 = `000000000011111111112222222222333333333344444444445555555555`
const mi2 = `012345678901234567890123456789012345678901234567890123456789`
const s1 = `000000000011111111112222222222333333333344444444445555555555`
const s2 = `012345678901234567890123456789012345678901234567890123456789`
func formatTimeHeader(when time.Time) ([]byte, int) { func formatTimeHeader(when time.Time) ([]byte, int) {
y, mo, d := when.Date() y, mo, d := when.Date()
h, mi, s := when.Clock() h, mi, s := when.Clock()
//len(2006/01/02 15:03:04)==19 //len("2006/01/02 15:04:05 ")==20
var buf [20]byte var buf [20]byte
t := 3
for y >= 10 { //change to '3' after 984 years, LOL
p := y / 10 buf[0] = '2'
buf[t] = byte('0' + y - p*10) //change to '1' after 84 years, LOL
y = p buf[1] = '0'
t-- buf[2] = y1[y-2000]
} buf[3] = y2[y-2000]
buf[0] = byte('0' + y)
buf[4] = '/' buf[4] = '/'
if mo > 9 { buf[5] = mo1[mo-1]
buf[5] = '1' buf[6] = mo2[mo-1]
buf[6] = byte('0' + mo - 9)
} else {
buf[5] = '0'
buf[6] = byte('0' + mo)
}
buf[7] = '/' buf[7] = '/'
t = d / 10 buf[8] = d1[d-1]
buf[8] = byte('0' + t) buf[9] = d2[d-1]
buf[9] = byte('0' + d - t*10)
buf[10] = ' ' buf[10] = ' '
t = h / 10 buf[11] = h1[h]
buf[11] = byte('0' + t) buf[12] = h2[h]
buf[12] = byte('0' + h - t*10)
buf[13] = ':' buf[13] = ':'
t = mi / 10 buf[14] = mi1[mi]
buf[14] = byte('0' + t) buf[15] = mi2[mi]
buf[15] = byte('0' + mi - t*10)
buf[16] = ':' buf[16] = ':'
t = s / 10 buf[17] = s1[s]
buf[17] = byte('0' + t) buf[18] = s2[s]
buf[18] = byte('0' + s - t*10)
buf[19] = ' ' buf[19] = ' '
return buf[0:], d return buf[0:], d

57
logs/logger_test.go Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"testing"
"time"
)
func TestFormatHeader_0(t *testing.T) {
tm := time.Now()
if tm.Year() >= 2100 {
t.FailNow()
}
dur := time.Second
for {
if tm.Year() >= 2100 {
break
}
h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
dur *= 2
}
}
func TestFormatHeader_1(t *testing.T) {
tm := time.Now()
year := tm.Year()
dur := time.Second
for {
if tm.Year() >= year+1 {
break
}
h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
}
}

View File

@ -112,5 +112,5 @@ func newFilesWriter() Logger {
} }
func init() { func init() {
Register("multifile", newFilesWriter) Register(AdapterMultiFile, newFilesWriter)
} }

View File

@ -156,5 +156,5 @@ func (s *SMTPWriter) Destroy() {
} }
func init() { func init() {
Register("smtp", newSMTPWriter) Register(AdapterMail, newSMTPWriter)
} }

View File

@ -33,7 +33,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/astaxie/beego" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
) )
@ -90,7 +90,7 @@ func (m *Migration) Reset() {
func (m *Migration) Exec(name, status string) error { func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm() o := orm.NewOrm()
for _, s := range m.sqls { for _, s := range m.sqls {
beego.Info("exec sql:", s) logs.Info("exec sql:", s)
r := o.Raw(s) r := o.Raw(s)
_, err := r.Exec() _, err := r.Exec()
if err != nil { if err != nil {
@ -144,20 +144,20 @@ func Upgrade(lasttime int64) error {
i := 0 i := 0
for _, v := range sm { for _, v := range sm {
if v.created > lasttime { if v.created > lasttime {
beego.Info("start upgrade", v.name) logs.Info("start upgrade", v.name)
v.m.Reset() v.m.Reset()
v.m.Up() v.m.Up()
err := v.m.Exec(v.name, "up") err := v.m.Exec(v.name, "up")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
beego.Info("end upgrade:", v.name) logs.Info("end upgrade:", v.name)
i++ i++
} }
} }
beego.Info("total success upgrade:", i, " migration") logs.Info("total success upgrade:", i, " migration")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
@ -165,20 +165,20 @@ func Upgrade(lasttime int64) error {
// Rollback rollback the migration by the name // Rollback rollback the migration by the name
func Rollback(name string) error { func Rollback(name string) error {
if v, ok := migrationMap[name]; ok { if v, ok := migrationMap[name]; ok {
beego.Info("start rollback") logs.Info("start rollback")
v.Reset() v.Reset()
v.Down() v.Down()
err := v.Exec(name, "down") err := v.Exec(name, "down")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
beego.Info("end rollback") logs.Info("end rollback")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
beego.Error("not exist the migrationMap name:" + name) logs.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name) return errors.New("not exist the migrationMap name:" + name)
} }
@ -191,23 +191,23 @@ func Reset() error {
for j := len(sm) - 1; j >= 0; j-- { for j := len(sm) - 1; j >= 0; j-- {
v := sm[j] v := sm[j]
if isRollBack(v.name) { if isRollBack(v.name) {
beego.Info("skip the", v.name) logs.Info("skip the", v.name)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
continue continue
} }
beego.Info("start reset:", v.name) logs.Info("start reset:", v.name)
v.m.Reset() v.m.Reset()
v.m.Down() v.m.Down()
err := v.m.Exec(v.name, "down") err := v.m.Exec(v.name, "down")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
i++ i++
beego.Info("end reset:", v.name) logs.Info("end reset:", v.name)
} }
beego.Info("total success reset:", i, " migration") logs.Info("total success reset:", i, " migration")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
@ -216,7 +216,7 @@ func Reset() error {
func Refresh() error { func Refresh() error {
err := Reset() err := Reset()
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
@ -265,7 +265,7 @@ func isRollBack(name string) bool {
var maps []orm.Params var maps []orm.Params
num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps)
if err != nil { if err != nil {
beego.Info("get name has error", err) logs.Info("get name has error", err)
return false return false
} }
if num <= 0 { if num <= 0 {

View File

@ -44,7 +44,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
return ns return ns
} }
// Cond set condtion function // Cond set condition function
// if cond return true can run this namespace, else can't // if cond return true can run this namespace, else can't
// usage: // usage:
// ns.Cond(func (ctx *context.Context) bool{ // ns.Cond(func (ctx *context.Context) bool{
@ -60,7 +60,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
exception("405", ctx) exception("405", ctx)
} }
} }
if v, ok := n.handlers.filters[BeforeRouter]; ok { if v := n.handlers.filters[BeforeRouter]; len(v) > 0 {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.tree = NewTree() mr.tree = NewTree()
mr.pattern = "*" mr.pattern = "*"

View File

@ -61,8 +61,8 @@ func TestNamespaceNest(t *testing.T) {
ns.Namespace( ns.Namespace(
NewNamespace("/admin"). NewNamespace("/admin").
Get("/order", func(ctx *context.Context) { Get("/order", func(ctx *context.Context) {
ctx.Output.Body([]byte("order")) ctx.Output.Body([]byte("order"))
}), }),
) )
AddNamespace(ns) AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r) BeeApp.Handlers.ServeHTTP(w, r)
@ -79,8 +79,8 @@ func TestNamespaceNestParam(t *testing.T) {
ns.Namespace( ns.Namespace(
NewNamespace("/admin"). NewNamespace("/admin").
Get("/order/:id", func(ctx *context.Context) { Get("/order/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id"))) ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}), }),
) )
AddNamespace(ns) AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r) BeeApp.Handlers.ServeHTTP(w, r)
@ -124,8 +124,8 @@ func TestNamespaceFilter(t *testing.T) {
ctx.Output.Body([]byte("this is Filter")) ctx.Output.Body([]byte("this is Filter"))
}). }).
Get("/user/:id", func(ctx *context.Context) { Get("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id"))) ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}) })
AddNamespace(ns) AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r) BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "this is Filter" { if w.Body.String() != "this is Filter" {

View File

@ -52,9 +52,15 @@ checkColumn:
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeCharField: case TypeCharField:
col = fmt.Sprintf(T["string"], fieldSize) if al.Driver == DRPostgres && fi.toText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize)
}
case TypeTextField: case TypeTextField:
col = T["string-text"] col = T["string-text"]
case TypeTimeField:
col = T["time.Time-clock"]
case TypeDateField: case TypeDateField:
col = T["time.Time-date"] col = T["time.Time-date"]
case TypeDateTimeField: case TypeDateTimeField:
@ -88,6 +94,18 @@ checkColumn:
} else { } else {
col = fmt.Sprintf(s, fi.digits, fi.decimals) col = fmt.Sprintf(s, fi.digits, fi.decimals)
} }
case TypeJSONField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
goto checkColumn
}
col = T["json"]
case TypeJsonbField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
goto checkColumn
}
col = T["jsonb"]
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size fieldSize = fi.relModelInfo.fields.pk.size
@ -264,7 +282,7 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField, TypeTextField: case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v return v
case TypeBitField, TypeSmallIntegerField, TypeIntegerField, case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
@ -276,6 +294,8 @@ func getColumnDefault(fi *fieldInfo) string {
case TypeBooleanField: case TypeBooleanField:
t = " DEFAULT %s " t = " DEFAULT %s "
d = "FALSE" d = "FALSE"
case TypeJSONField, TypeJsonbField:
d = "{}"
} }
if fi.colDefault { if fi.colDefault {

133
orm/db.go
View File

@ -24,6 +24,7 @@ import (
) )
const ( const (
formatTime = "15:04:05"
formatDate = "2006-01-02" formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05" formatDateTime = "2006-01-02 15:04:05"
) )
@ -71,12 +72,12 @@ type dbBase struct {
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
// get struct columns values as interface slice. // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
var columns []string if names == nil {
ns := make([]string, 0, len(cols))
if names != nil { names = &ns
columns = *names
} }
values = make([]interface{}, 0, len(cols))
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
@ -90,18 +91,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if names != nil { // ignore empty value auto field
columns = append(columns, column) if insert && fi.auto {
if fi.fieldType&IsPositiveIntegerField > 0 {
if vu, ok := value.(uint64); !ok || vu == 0 {
continue
}
} else {
if vu, ok := value.(int64); !ok || vu == 0 {
continue
}
}
autoFields = append(autoFields, fi.column)
} }
values = append(values, value) *names, values = append(*names, column), append(values, value)
}
if names != nil {
*names = columns
} }
return return
@ -134,7 +141,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Bool() value = field.Bool()
} }
case TypeCharField, TypeTextField: case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
value = nil value = nil
if ns.Valid { if ns.Valid {
@ -169,7 +176,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
value = field.Float() value = field.Float()
} }
} }
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
value = field.Interface() value = field.Interface()
if t, ok := value.(time.Time); ok { if t, ok := value.(time.Time); ok {
d.ins.TimeToDB(&t, tz) d.ins.TimeToDB(&t, tz)
@ -181,7 +188,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
default: default:
switch { switch {
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPositiveIntegerField > 0:
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Ptr {
if field.IsNil() { if field.IsNil() {
value = nil value = nil
@ -223,7 +230,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
} }
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if fi.autoNow || fi.autoNowAdd && insert { if fi.autoNow || fi.autoNowAdd && insert {
if insert { if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() { if t, ok := value.(time.Time); ok && !t.IsZero() {
@ -236,10 +243,21 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if fi.isFielder { if fi.isFielder {
f := field.Addr().Interface().(Fielder) f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc)) f.SetRaw(tnow.In(DefaultTimeLoc))
} else if field.Kind() == reflect.Ptr {
v := tnow.In(DefaultTimeLoc)
field.Set(reflect.ValueOf(&v))
} else { } else {
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
} }
} }
case TypeJSONField, TypeJsonbField:
if s, ok := value.(string); (ok && len(s) == 0) || value == nil {
if fi.colDefault && fi.initial.Exist() {
value = fi.initial.String()
} else {
value = nil
}
}
} }
} }
return value, nil return value, nil
@ -273,7 +291,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
// insert struct with prepared statement and given struct reflect value. // insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -300,7 +318,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols = make([]string, 0, len(cols)) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -349,13 +367,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// execute insert sql dbQuerier with given struct reflect.Value. // execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols)-1) names := make([]string, 0, len(mi.fields.dbcols))
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, false, names, values) id, err := d.InsertValue(q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return id, err
} }
// multi-insert sql with given slice struct reflect.Value. // multi-insert sql with given slice struct reflect.Value.
@ -369,7 +395,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// typ := reflect.Indirect(mi.addrField).Type() // typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len() length, autoFields := sind.Len(), make([]string, 0, 1)
for i := 1; i <= length; i++ { for i := 1; i <= length; i++ {
@ -381,16 +407,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// } // }
if i == 1 { if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) var (
vus []interface{}
err error
)
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
values = make([]interface{}, bulk*len(vus)) values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus) nums += copy(values, vus)
} else { } else {
vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -412,7 +440,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
} }
} }
return cnt, nil var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return cnt, err
} }
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
@ -472,7 +505,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
setNames = make([]string, 0, len(cols)) setNames = make([]string, 0, len(cols))
} }
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -516,7 +549,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0)
} else { } else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
@ -1071,13 +1104,13 @@ setValue:
} }
value = b value = b
} }
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil { if str == nil {
value = ToStr(val) value = ToStr(val)
} else { } else {
value = str.String() value = str.String()
} }
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil { if str == nil {
switch t := val.(type) { switch t := val.(type) {
case time.Time: case time.Time:
@ -1097,15 +1130,20 @@ setValue:
if len(s) >= 19 { if len(s) >= 19 {
s = s[:19] s = s[:19]
t, err = time.ParseInLocation(formatDateTime, s, tz) t, err = time.ParseInLocation(formatDateTime, s, tz)
} else { } else if len(s) >= 10 {
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} else if len(s) >= 8 {
if len(s) > 8 {
s = s[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
} }
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
tErr = err tErr = err
goto end goto end
} }
@ -1140,7 +1178,7 @@ setValue:
tErr = err tErr = err
goto end goto end
} }
if fieldType&IsPostiveIntegerField > 0 { if fieldType&IsPositiveIntegerField > 0 {
v, _ := str.Uint64() v, _ := str.Uint64()
value = v value = v
} else { } else {
@ -1212,7 +1250,7 @@ setValue:
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
} }
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil { if value == nil {
@ -1234,12 +1272,18 @@ setValue:
field.SetString(value.(string)) field.SetString(value.(string))
} }
} }
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative { if isNative {
if value == nil { if value == nil {
value = time.Time{} value = time.Time{}
} else if field.Kind() == reflect.Ptr {
if value != nil {
v := value.(time.Time)
field.Set(reflect.ValueOf(&v))
}
} else {
field.Set(reflect.ValueOf(value))
} }
field.Set(reflect.ValueOf(value))
} }
case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr:
if value != nil { if value != nil {
@ -1292,7 +1336,7 @@ setValue:
field.Set(reflect.ValueOf(&v)) field.Set(reflect.ValueOf(&v))
} }
case fieldType&IsIntegerField > 0: case fieldType&IsIntegerField > 0:
if fieldType&IsPostiveIntegerField > 0 { if fieldType&IsPositiveIntegerField > 0 {
if isNative { if isNative {
if value == nil { if value == nil {
value = uint64(0) value = uint64(0)
@ -1440,7 +1484,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) sqlSelect := "SELECT"
if qs.distinct {
sqlSelect += " DISTINCT"
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
@ -1562,6 +1610,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
// convert time from db. // convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)

View File

@ -56,6 +56,8 @@ var postgresTypes = map[string]string{
"uint64": `bigint CHECK("%COL%" >= 0)`, "uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision", "float64": "double precision",
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
"json": "json",
"jsonb": "jsonb",
} }
// postgresql dbBaser. // postgresql dbBaser.
@ -123,14 +125,35 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
} }
// make returning sql support for postgresql. // make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
if mi.fields.pk.auto { fi := mi.fields.pk
if query != nil { if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) return false
}
has = true
} }
return
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
}
return true
}
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
return err
}
}
return nil
} }
// show table sql for postgresql. // show table sql for postgresql.

View File

@ -33,13 +33,13 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
fi := mi.fields.pk fi := mi.fields.pk
v := ind.FieldByIndex(fi.fieldIndex) v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPostiveIntegerField > 0 { if fi.fieldType&IsPositiveIntegerField > 0 {
vu := v.Uint() vu := v.Uint()
exist = vu > 0 exist = vu > 0
value = vu value = vu
} else if fi.fieldType&IsIntegerField > 0 { } else if fi.fieldType&IsIntegerField > 0 {
vu := v.Int() vu := v.Int()
exist = vu > 0 exist = true
value = vu value = vu
} else { } else {
vu := v.String() vu := v.String()
@ -74,24 +74,32 @@ outFor:
case reflect.String: case reflect.String:
v := val.String() v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time var t time.Time
var err error var err error
if len(v) >= 19 { if len(v) >= 19 {
s := v[:19] s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else { } else if len(v) >= 10 {
s := v s := v
if len(v) > 10 { if len(v) > 10 {
s = v[:10] s = v[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} else {
s := v
if len(s) > 8 {
s = v[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
} }
if err == nil { if err == nil {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate) v = t.In(tz).Format(formatDate)
} else { } else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime) v = t.In(tz).Format(formatDateTime)
} else {
v = t.In(tz).Format(formatTime)
} }
} }
} }
@ -137,8 +145,10 @@ outFor:
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate) arg = v.In(tz).Format(formatDate)
} else { } else if fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime) arg = v.In(tz).Format(formatDateTime)
} else {
arg = v.In(tz).Format(formatTime)
} }
} else { } else {
typ := val.Type() typ := val.Type()

View File

@ -66,7 +66,7 @@ func registerModel(prefix string, model interface{}) {
} }
if info.fields.pk == nil { if info.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name) fmt.Printf("<orm.RegisterModel> `%s` need a primary key field, default use 'id' if not set\n", name)
os.Exit(2) os.Exit(2)
} }

View File

@ -25,6 +25,7 @@ const (
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
TypeCharField TypeCharField
TypeTextField TypeTextField
TypeTimeField
TypeDateField TypeDateField
TypeDateTimeField TypeDateTimeField
TypeBitField TypeBitField
@ -37,6 +38,8 @@ const (
TypePositiveBigIntegerField TypePositiveBigIntegerField
TypeFloatField TypeFloatField
TypeDecimalField TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey RelForeignKey
RelOneToOne RelOneToOne
RelManyToMany RelManyToMany
@ -46,10 +49,10 @@ const (
// Define some logic enum // Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10
IsRelField = ^-RelReverseMany >> 14 << 15 IsRelField = ^-RelReverseMany >> 17 << 18
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
// BooleanField A true/false field. // BooleanField A true/false field.
@ -145,6 +148,65 @@ func (e *CharField) RawValue() interface{} {
// verify CharField implement Fielder // verify CharField implement Fielder
var _ Fielder = new(CharField) var _ Fielder = new(CharField)
// TimeField A time, represented in go by a time.Time instance.
// only time values like 10:00:00
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err != nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField)
// DateField A date, represented in go by a time.Time instance. // DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02 // only date values like 2006-01-02
// Has a few extra, optional attr tag: // Has a few extra, optional attr tag:
@ -627,3 +689,87 @@ func (e *TextField) RawValue() interface{} {
// verify TextField implement Fielder // verify TextField implement Fielder
var _ Fielder = new(TextField) var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)

View File

@ -119,6 +119,7 @@ type fieldInfo struct {
colDefault bool colDefault bool
initial StrTo initial StrTo
size int size int
toText bool
autoNow bool autoNow bool
autoNowAdd bool autoNowAdd bool
rel bool rel bool
@ -239,8 +240,15 @@ checkType:
if err != nil { if err != nil {
goto end goto end
} }
if fieldType == TypeCharField && tags["type"] == "text" { if fieldType == TypeCharField {
fieldType = TypeTextField switch tags["type"] {
case "text":
fieldType = TypeTextField
case "json":
fieldType = TypeJSONField
case "jsonb":
fieldType = TypeJsonbField
}
} }
if fieldType == TypeFloatField && (digits != "" || decimals != "") { if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField fieldType = TypeDecimalField
@ -248,6 +256,9 @@ checkType:
if fieldType == TypeDateTimeField && tags["type"] == "date" { if fieldType == TypeDateTimeField && tags["type"] == "date" {
fieldType = TypeDateField fieldType = TypeDateField
} }
if fieldType == TypeTimeField && tags["type"] == "time" {
fieldType = TypeTimeField
}
} }
switch fieldType { switch fieldType {
@ -339,7 +350,7 @@ checkType:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
case TypeCharField: case TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" { if size != "" {
v, e := StrTo(size).Int32() v, e := StrTo(size).Int32()
if e != nil { if e != nil {
@ -349,11 +360,12 @@ checkType:
} }
} else { } else {
fi.size = 255 fi.size = 255
fi.toText = true
} }
case TypeTextField: case TypeTextField:
fi.index = false fi.index = false
fi.unique = false fi.unique = false
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if attrs["auto_now"] { if attrs["auto_now"] {
fi.autoNow = true fi.autoNow = true
} else if attrs["auto_now_add"] { } else if attrs["auto_now_add"] {
@ -406,7 +418,7 @@ checkType:
fi.index = false fi.index = false
} }
if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default // can not set default
initial.Clear() initial.Clear()
} }

View File

@ -78,40 +78,43 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField) var _ Fielder = new(SliceStringField)
// A json field. // A json field.
type JSONField struct { type JSONFieldTest struct {
Name string Name string
Data string Data string
} }
func (e *JSONField) String() string { func (e *JSONFieldTest) String() string {
data, _ := json.Marshal(e) data, _ := json.Marshal(e)
return string(data) return string(data)
} }
func (e *JSONField) FieldType() int { func (e *JSONFieldTest) FieldType() int {
return TypeTextField return TypeTextField
} }
func (e *JSONField) SetRaw(value interface{}) error { func (e *JSONFieldTest) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
return json.Unmarshal([]byte(d), e) return json.Unmarshal([]byte(d), e)
default: default:
return fmt.Errorf("<JsonField.SetRaw> unknown value `%v`", value) return fmt.Errorf("<JSONField.SetRaw> unknown value `%v`", value)
} }
} }
func (e *JSONField) RawValue() interface{} { func (e *JSONFieldTest) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(JSONField) var _ Fielder = new(JSONFieldTest)
type Data struct { type Data struct {
ID int `orm:"column(id)"` ID int `orm:"column(id)"`
Boolean bool Boolean bool
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
JSON string `orm:"type(json);default({\"name\":\"json\"})"`
Jsonb string `orm:"type(jsonb)"`
Time time.Time `orm:"type(time)"`
Date time.Time `orm:"type(date)"` Date time.Time `orm:"type(date)"`
DateTime time.Time `orm:"column(datetime)"` DateTime time.Time `orm:"column(datetime)"`
Byte byte Byte byte
@ -136,6 +139,9 @@ type DataNull struct {
Boolean bool `orm:"null"` Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"` Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"` Text string `orm:"null;type(text)"`
JSON string `orm:"type(json);null"`
Jsonb string `orm:"type(jsonb);null"`
Time time.Time `orm:"null;type(time)"`
Date time.Time `orm:"null;type(date)"` Date time.Time `orm:"null;type(date)"`
DateTime time.Time `orm:"null;column(datetime)"` DateTime time.Time `orm:"null;column(datetime)"`
Byte byte `orm:"null"` Byte byte `orm:"null"`
@ -175,6 +181,9 @@ type DataNull struct {
Float32Ptr *float32 `orm:"null"` Float32Ptr *float32 `orm:"null"`
Float64Ptr *float64 `orm:"null"` Float64Ptr *float64 `orm:"null"`
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
TimePtr *time.Time `orm:"null;type(time)"`
DatePtr *time.Time `orm:"null;type(date)"`
DateTimePtr *time.Time `orm:"null"`
} }
type String string type String string
@ -237,7 +246,7 @@ type User struct {
ShouldSkip string `orm:"-"` ShouldSkip string `orm:"-"`
Nums int Nums int
Langs SliceStringField `orm:"size(100)"` Langs SliceStringField `orm:"size(100)"`
Extra JSONField `orm:"type(text)"` Extra JSONFieldTest `orm:"type(text)"`
unexport bool `orm:"-"` unexport bool `orm:"-"`
unexportBool bool unexportBool bool
} }
@ -375,6 +384,28 @@ func NewInLine() *InLine {
return new(InLine) return new(InLine)
} }
type InLineOneToOne struct {
// Common Fields
ModelBase
Note string
InLine *InLine `orm:"rel(fk);column(inline)"`
}
func NewInLineOneToOne() *InLineOneToOne {
return new(InLineOneToOne)
}
type IntegerPk struct {
ID int64 `orm:"pk"`
Value string
}
type UintPk struct {
ID uint32 `orm:"pk"`
Name string
}
var DBARGS = struct { var DBARGS = struct {
Driver string Driver string
Source string Source string

View File

@ -137,6 +137,8 @@ func getFieldType(val reflect.Value) (ft int, err error) {
ft = TypeBooleanField ft = TypeBooleanField
case reflect.TypeOf(new(string)): case reflect.TypeOf(new(string)):
ft = TypeCharField ft = TypeCharField
case reflect.TypeOf(new(time.Time)):
ft = TypeDateTimeField
default: default:
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
@ -192,10 +194,10 @@ func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string
tag := make(map[string]string) tag := make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) { for _, v := range strings.Split(data, defaultStructTagDelim) {
v = strings.TrimSpace(v) v = strings.TrimSpace(v)
if supportTag[v] == 1 { if t := strings.ToLower(v); supportTag[t] == 1 {
attr[v] = true attr[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := v[:i] name := t[:i]
if supportTag[name] == 2 { if supportTag[name] == 2 {
v = v[i+1 : len(v)-1] v = v[i+1 : len(v)-1]
tag[name] = v tag[name] = v

View File

@ -140,7 +140,14 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
return (err == nil), id, err return (err == nil), id, err
} }
return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint())
} else {
id = vid.Int()
}
return false, id, err
} }
// insert model data to database // insert model data to database
@ -159,7 +166,7 @@ func (o *orm) Insert(md interface{}) (int64, error) {
// set auto pk field // set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
@ -184,7 +191,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
if bulk <= 1 { if bulk <= 1 {
for i := 0; i < sind.Len(); i++ { for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i) ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false) mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {

View File

@ -31,7 +31,7 @@ type Log struct {
// NewLog set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d return d
} }

View File

@ -50,7 +50,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
} }
if id > 0 { if id > 0 {
if o.mi.fields.pk.auto { if o.mi.fields.pk.auto {
if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)

View File

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -33,6 +34,7 @@ var _ = os.PathSeparator
var ( var (
testDate = formatDate + " -0700" testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700" testDateTime = formatDateTime + " -0700"
testTime = formatTime + " -0700"
) )
type argAny []interface{} type argAny []interface{}
@ -188,6 +190,9 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Permission)) RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions)) RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine)) RegisterModel(new(InLine))
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
@ -208,6 +213,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Permission)) RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions)) RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine)) RegisterModel(new(InLine))
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
BootStrap() BootStrap()
@ -233,6 +241,9 @@ var DataValues = map[string]interface{}{
"Boolean": true, "Boolean": true,
"Char": "char", "Char": "char",
"Text": "text", "Text": "text",
"JSON": `{"name":"json"}`,
"Jsonb": `{"name": "jsonb"}`,
"Time": time.Now(),
"Date": time.Now(), "Date": time.Now(),
"DateTime": time.Now(), "DateTime": time.Now(),
"Byte": byte(1<<8 - 1), "Byte": byte(1<<8 - 1),
@ -257,10 +268,12 @@ func TestDataTypes(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range DataValues { for name, value := range DataValues {
if name == "JSON" {
continue
}
e := ind.FieldByName(name) e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value)) e.Set(reflect.ValueOf(value))
} }
id, err := dORM.Insert(&d) id, err := dORM.Insert(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
@ -281,6 +294,9 @@ func TestDataTypes(t *testing.T) {
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -299,10 +315,18 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}`
d = DataNull{ID: 1, JSON: data}
num, err := dORM.Update(&d)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
d = DataNull{ID: 1} d = DataNull{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(d.JSON, data))
throwFail(t, AssertIs(d.NullBool.Valid, false)) throwFail(t, AssertIs(d.NullBool.Valid, false))
throwFail(t, AssertIs(d.NullString.Valid, false)) throwFail(t, AssertIs(d.NullString.Valid, false))
throwFail(t, AssertIs(d.NullInt64.Valid, false)) throwFail(t, AssertIs(d.NullInt64.Valid, false))
@ -326,6 +350,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.Float32Ptr, nil)) throwFail(t, AssertIs(d.Float32Ptr, nil))
throwFail(t, AssertIs(d.Float64Ptr, nil)) throwFail(t, AssertIs(d.Float64Ptr, nil))
throwFail(t, AssertIs(d.DecimalPtr, nil)) throwFail(t, AssertIs(d.DecimalPtr, nil))
throwFail(t, AssertIs(d.TimePtr, nil))
throwFail(t, AssertIs(d.DatePtr, nil))
throwFail(t, AssertIs(d.DateTimePtr, nil))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
@ -352,6 +379,9 @@ func TestNullDataTypes(t *testing.T) {
float32Ptr := float32(42.0) float32Ptr := float32(42.0)
float64Ptr := float64(42.0) float64Ptr := float64(42.0)
decimalPtr := float64(42.0) decimalPtr := float64(42.0)
timePtr := time.Now()
datePtr := time.Now()
dateTimePtr := time.Now()
d = DataNull{ d = DataNull{
DateTime: time.Now(), DateTime: time.Now(),
@ -377,6 +407,9 @@ func TestNullDataTypes(t *testing.T) {
Float32Ptr: &float32Ptr, Float32Ptr: &float32Ptr,
Float64Ptr: &float64Ptr, Float64Ptr: &float64Ptr,
DecimalPtr: &decimalPtr, DecimalPtr: &decimalPtr,
TimePtr: &timePtr,
DatePtr: &datePtr,
DateTimePtr: &dateTimePtr,
} }
id, err = dORM.Insert(&d) id, err = dORM.Insert(&d)
@ -417,6 +450,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime)))
throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate)))
throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime)))
} }
func TestDataCustomTypes(t *testing.T) { func TestDataCustomTypes(t *testing.T) {
@ -1521,6 +1557,7 @@ func TestRawQueryRow(t *testing.T) {
Boolean bool Boolean bool
Char string Char string
Text string Text string
Time time.Time
Date time.Time Date time.Time
DateTime time.Time DateTime time.Time
Byte byte Byte byte
@ -1549,14 +1586,14 @@ func TestRawQueryRow(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
cols := []string{ cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
} }
sep := fmt.Sprintf("%s, %s", Q, Q) sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
var id int var id int
values := []interface{}{ values := []interface{}{
&id, &Boolean, &Char, &Text, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32,
&Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal,
} }
err := dORM.Raw(query, 1).QueryRow(values...) err := dORM.Raw(query, 1).QueryRow(values...)
@ -1567,6 +1604,10 @@ func TestRawQueryRow(t *testing.T) {
switch col { switch col {
case "id": case "id":
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
case "time":
v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testTime))
case "date": case "date":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
@ -1614,6 +1655,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@ -1638,6 +1682,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@ -1959,3 +2006,171 @@ func TestInLine(t *testing.T) {
throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime))
} }
func TestInLineOneToOne(t *testing.T) {
name := "121"
email := "121@go.com"
inline := NewInLine()
inline.Name = name
inline.Email = email
id, err := dORM.Insert(inline)
throwFail(t, err)
throwFail(t, AssertIs(id, 2))
note := "one2one"
il121 := NewInLineOneToOne()
il121.Note = note
il121.InLine = inline
_, err = dORM.Insert(il121)
throwFail(t, err)
throwFail(t, AssertIs(il121.ID, 1))
il := NewInLineOneToOne()
err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il)
throwFail(t, err)
throwFail(t, AssertIs(il.Note, note))
throwFail(t, AssertIs(il.InLine.ID, id))
throwFail(t, AssertIs(il.InLine.Name, name))
throwFail(t, AssertIs(il.InLine.Email, email))
rinline := NewInLine()
err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline)
throwFail(t, err)
throwFail(t, AssertIs(rinline.ID, id))
throwFail(t, AssertIs(rinline.Name, name))
throwFail(t, AssertIs(rinline.Email, email))
}
func TestIntegerPk(t *testing.T) {
its := []IntegerPk{
{ID: math.MinInt64, Value: "-"},
{ID: 0, Value: "0"},
{ID: math.MaxInt64, Value: "+"},
}
num, err := dORM.InsertMulti(len(its), its)
throwFail(t, err)
throwFail(t, AssertIs(num, len(its)))
for _, intPk := range its {
out := IntegerPk{ID: intPk.ID}
err = dORM.Read(&out)
throwFail(t, err)
throwFail(t, AssertIs(out.Value, intPk.Value))
}
num, err = dORM.InsertMulti(1, []*IntegerPk{&IntegerPk{
ID: 1, Value: "ok",
}})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestInsertAuto(t *testing.T) {
u := &User{
UserName: "autoPre",
Email: "autoPre@gmail.com",
}
id, err := dORM.Insert(u)
throwFail(t, err)
id += 100
su := &User{
ID: int(id),
UserName: "auto",
Email: "auto@gmail.com",
}
nid, err := dORM.Insert(su)
throwFail(t, err)
throwFail(t, AssertIs(nid, id))
users := []User{
{ID: int(id + 100), UserName: "auto_100"},
{ID: int(id + 110), UserName: "auto_110"},
{ID: int(id + 120), UserName: "auto_120"},
}
num, err := dORM.InsertMulti(100, users)
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
u = &User{
UserName: "auto_121",
}
nid, err = dORM.Insert(u)
throwFail(t, err)
throwFail(t, AssertIs(nid, id+120+1))
}
func TestUintPk(t *testing.T) {
name := "go"
u := &UintPk{
ID: 8,
Name: name,
}
created, pk, err := dORM.ReadOrCreate(u, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.Name, name))
nu := &UintPk{ID: 8}
created, pk, err = dORM.ReadOrCreate(nu, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.Name, name))
dORM.Delete(u)
}
func TestSnake(t *testing.T) {
cases := map[string]string{
"i": "i",
"I": "i",
"iD": "i_d",
"ID": "id",
"NO": "no",
"NOO": "noo",
"NOOooOOoo": "noo_oo_oo_oo",
"OrderNO": "order_no",
"tagName": "tag_name",
"tag_Name": "tag_name",
"tag_name": "tag_name",
"_tag_name": "_tag_name",
"tag_666name": "tag_666name",
"tag_666Name": "tag_666_name",
}
for name, want := range cases {
got := snakeString(name)
throwFail(t, AssertIs(got, want))
}
}
func TestIgnoreCaseTag(t *testing.T) {
type testTagModel struct {
ID int `orm:"pk"`
NOO string `orm:"column(n)"`
Name01 string `orm:"NULL"`
Name02 string `orm:"COLUMN(Name)"`
Name03 string `orm:"Column(name)"`
}
modelCache.clean()
RegisterModel(&testTagModel{})
info, ok := modelCache.get("test_tag_model")
throwFail(t, AssertIs(ok, true))
throwFail(t, AssertNot(info, nil))
if t == nil {
return
}
throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n"))
throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true))
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
}

View File

@ -420,4 +420,5 @@ type dbBaser interface {
ShowColumnsQuery(string) string ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
} }

View File

@ -181,18 +181,36 @@ func ToInt64(value interface{}) (d int64) {
return return
} }
// snake string, XxYy to xx_yy // snake string, XxYy to xx_yy , XxYY to xx_yy
func snakeString(s string) string { func snakeString(s string) string {
data := make([]byte, 0, len(s)*2) data := make([]byte, 0, len(s)*2)
j := false
num := len(s) num := len(s)
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
d := s[i] d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j { if i > 0 && d != '_' && s[i-1] != '_' {
data = append(data, '_') need := false
} // upper as 1, lower as 0
if d != '_' { // XX -> 11 -> 11
j = true // Xx -> 10 -> 10
// XxYyZZ -> 101011 -> 10_10_11
isUpper := d >= 'A' && d <= 'Z'
preIsUpper := s[i-1] >= 'A' && s[i-1] <= 'Z'
if isUpper {
// like : xxYy
if !preIsUpper {
need = true
}
} else {
if preIsUpper {
// ignore "Xy" in "xxXyy"
if i-2 >= 0 && s[i-2] >= 'A' && s[i-2] <= 'Z' {
need = true
}
}
}
if need {
data = append(data, '_')
}
} }
data = append(data, d) data = append(data, d)
} }

View File

@ -23,10 +23,11 @@ import (
"go/token" "go/token"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path/filepath"
"sort" "sort"
"strings" "strings"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -55,10 +56,11 @@ func init() {
} }
func parserPkg(pkgRealpath, pkgpath string) error { func parserPkg(pkgRealpath, pkgpath string) error {
rep := strings.NewReplacer("/", "_", ".", "_") rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go" commentFilename, _ = filepath.Rel(AppPath, pkgRealpath)
commentFilename = coomentPrefix + rep.Replace(commentFilename) + ".go"
if !compareFile(pkgRealpath) { if !compareFile(pkgRealpath) {
Info(pkgRealpath + " no changed") logs.Info(pkgRealpath + " no changed")
return nil return nil
} }
genInfoList = make(map[string][]ControllerComments) genInfoList = make(map[string][]ControllerComments)
@ -86,7 +88,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
} }
} }
} }
genRouterCode() genRouterCode(pkgRealpath)
savetoFile(pkgRealpath) savetoFile(pkgRealpath)
return nil return nil
} }
@ -129,9 +131,9 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
return nil return nil
} }
func genRouterCode() { func genRouterCode(pkgRealpath string) {
os.Mkdir(path.Join(AppPath, "routers"), 0755) os.Mkdir(getRouterDir(pkgRealpath), 0755)
Info("generate router from comments") logs.Info("generate router from comments")
var ( var (
globalinfo string globalinfo string
sortKey []string sortKey []string
@ -164,15 +166,15 @@ func genRouterCode() {
globalinfo = globalinfo + ` globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{ beego.ControllerComments{
"` + strings.TrimSpace(c.Method) + `", Method: "` + strings.TrimSpace(c.Method) + `",
` + "`" + c.Router + "`" + `, ` + "Router: `" + c.Router + "`" + `,
` + allmethod + `, AllowHTTPMethods: ` + allmethod + `,
` + params + `}) Params: ` + params + `})
` `
} }
} }
if globalinfo != "" { if globalinfo != "" {
f, err := os.Create(path.Join(AppPath, "routers", commentFilename)) f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -182,7 +184,7 @@ func genRouterCode() {
} }
func compareFile(pkgRealpath string) bool { func compareFile(pkgRealpath string) bool {
if !utils.FileExists(path.Join(AppPath, "routers", commentFilename)) { if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
return true return true
} }
if utils.FileExists(lastupdateFilename) { if utils.FileExists(lastupdateFilename) {
@ -229,3 +231,19 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
} }
return lastupdate, nil return lastupdate, nil
} }
func getRouterDir(pkgRealpath string) string {
dir := filepath.Dir(pkgRealpath)
for {
d := filepath.Join(dir, "routers")
if utils.FileExists(d) {
return d
}
if r, _ := filepath.Rel(dir, AppPath); r == "." {
return d
}
// Parent dir.
dir = filepath.Dir(dir)
}
}

111
router.go
View File

@ -28,6 +28,7 @@ import (
"time" "time"
beecontext "github.com/astaxie/beego/context" beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -114,7 +115,7 @@ type controllerInfo struct {
type ControllerRegister struct { type ControllerRegister struct {
routers map[string]*Tree routers map[string]*Tree
enableFilter bool enableFilter bool
filters map[int][]*FilterRouter filters [FinishRouter + 1][]*FilterRouter
pool sync.Pool pool sync.Pool
} }
@ -122,7 +123,6 @@ type ControllerRegister struct {
func NewControllerRegister() *ControllerRegister { func NewControllerRegister() *ControllerRegister {
cr := &ControllerRegister{ cr := &ControllerRegister{
routers: make(map[string]*Tree), routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter),
} }
cr.pool.New = func() interface{} { cr.pool.New = func() interface{} {
return beecontext.NewContext() return beecontext.NewContext()
@ -408,7 +408,6 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
// InsertFilter Add a FilterFunc with pattern rule and action constant. // InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.tree = NewTree() mr.tree = NewTree()
mr.pattern = pattern mr.pattern = pattern
@ -426,9 +425,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
} }
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
p.filters[pos] = append(p.filters[pos], mr) if pos < BeforeStatic || pos > FinishRouter {
err = fmt.Errorf("can not find your filter postion")
return
}
p.enableFilter = true p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr)
return nil return nil
} }
@ -437,11 +440,11 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error
func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".") paths := strings.Split(endpoint, ".")
if len(paths) <= 1 { if len(paths) <= 1 {
Warn("urlfor endpoint must like path.controller.method") logs.Warn("urlfor endpoint must like path.controller.method")
return "" return ""
} }
if len(values)%2 != 0 { if len(values)%2 != 0 {
Warn("urlfor params must key-value pair") logs.Warn("urlfor params must key-value pair")
return "" return ""
} }
params := make(map[string]string) params := make(map[string]string)
@ -577,20 +580,16 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return false, "" return false, ""
} }
func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
if p.enableFilter { for _, filterR := range p.filters[pos] {
if l, ok := p.filters[pos]; ok { if filterR.returnOnOutput && context.ResponseWriter.Started {
for _, filterR := range l { return true
if filterR.returnOnOutput && context.ResponseWriter.Started { }
return true if ok := filterR.ValidRouter(urlPath, context); ok {
} filterR.filterFunc(context)
if ok := filterR.ValidRouter(urlPath, context); ok { }
filterR.filterFunc(context) if filterR.returnOnOutput && context.ResponseWriter.Started {
} return true
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
}
}
} }
} }
return false return false
@ -617,11 +616,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Output.Header("Server", BConfig.ServerName) context.Output.Header("Server", BConfig.ServerName)
} }
var urlPath string var urlPath = r.URL.Path
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path) urlPath = strings.ToLower(urlPath)
} else {
urlPath = r.URL.Path
} }
// filter wrong http method // filter wrong http method
@ -631,11 +629,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// filter for static file // filter for static file
if p.execFilter(context, BeforeStatic, urlPath) { if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) {
goto Admin goto Admin
} }
serverStaticRouter(context) serverStaticRouter(context)
if context.ResponseWriter.Started { if context.ResponseWriter.Started {
findRouter = true findRouter = true
goto Admin goto Admin
@ -653,9 +652,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
var err error var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil { if err != nil {
Error(err) logs.Error(err)
exception("503", context) exception("503", context)
return goto Admin
} }
defer func() { defer func() {
if context.Input.CruSession != nil { if context.Input.CruSession != nil {
@ -663,8 +662,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
}() }()
} }
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin goto Admin
} }
@ -693,7 +691,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if findRouter { if findRouter {
//execute middleware filters //execute middleware filters
if p.execFilter(context, BeforeExec, urlPath) { if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
goto Admin goto Admin
} }
isRunnable := false isRunnable := false
@ -783,7 +781,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !context.ResponseWriter.Started && context.Output.Status == 0 { if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
panic(err) logs.Error(err)
} }
} }
} }
@ -794,17 +792,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//execute middleware filters //execute middleware filters
if p.execFilter(context, AfterExec, urlPath) { if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
goto Admin goto Admin
} }
} }
if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) {
p.execFilter(context, FinishRouter, urlPath) goto Admin
}
Admin: Admin:
timeDur := time.Since(startTime)
//admin module record QPS //admin module record QPS
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime)
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
if runRouter != nil { if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
@ -815,6 +814,7 @@ Admin:
} }
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
timeDur := time.Since(startTime)
var devInfo string var devInfo string
if findRouter { if findRouter {
if routerInfo != nil { if routerInfo != nil {
@ -826,7 +826,7 @@ Admin:
devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch") devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
} }
if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) { if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
Debug(devInfo) logs.Debug(devInfo)
} }
} }
@ -843,27 +843,26 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
} }
if !BConfig.RecoverPanic { if !BConfig.RecoverPanic {
panic(err) panic(err)
} else { }
if BConfig.EnableErrorsShow { if BConfig.EnableErrorsShow {
if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
exception(fmt.Sprint(err), context) exception(fmt.Sprint(err), context)
return return
}
} }
var stack string }
Critical("the request url is ", context.Input.URL()) var stack string
Critical("Handler crashed with error", err) logs.Critical("the request url is ", context.Input.URL())
for i := 1; ; i++ { logs.Critical("Handler crashed with error", err)
_, file, line, ok := runtime.Caller(i) for i := 1; ; i++ {
if !ok { _, file, line, ok := runtime.Caller(i)
break if !ok {
} break
Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
}
if BConfig.RunMode == DEV {
showErr(err, context, stack)
} }
logs.Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
}
if BConfig.RunMode == DEV {
showErr(err, context, stack)
} }
} }
} }

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
type TestController struct { type TestController struct {
@ -94,7 +95,7 @@ func TestUrlFor(t *testing.T) {
handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param") handler.Add("/person/:last/:first", &TestController{}, "*:Param")
if a := handler.URLFor("TestController.List"); a != "/api/list" { if a := handler.URLFor("TestController.List"); a != "/api/list" {
Info(a) logs.Info(a)
t.Errorf("TestController.List must equal to /api/list") t.Errorf("TestController.List must equal to /api/list")
} }
if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
@ -120,24 +121,24 @@ func TestUrlFor2(t *testing.T) {
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
Info(handler.URLFor("TestController.GetURL")) logs.Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit") t.Errorf("TestController.List must equal to /v1/astaxie/edit")
} }
if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" { "/v1/za/cms_12_123.html" {
Info(handler.URLFor("TestController.List")) logs.Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
} }
if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" { "/v1/za_cms/ttt_12_123.html" {
Info(handler.URLFor("TestController.Param")) logs.Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
} }
if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") != ":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" { "/1111/11/aaaa/aaaa" {
Info(handler.URLFor("TestController.Get")) logs.Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
} }
} }

View File

@ -115,7 +115,6 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
} }
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid) b, time.Now().Unix(), st.sid)
} }
// Provider mysql session provider // Provider mysql session provider

View File

@ -16,7 +16,6 @@ package session
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -82,14 +81,17 @@ func (fs *FileSessionStore) SessionID() string {
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
if err != nil { if err != nil {
SLogger.Println(err)
return return
} }
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File var f *os.File
if err == nil { if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
SLogger.Println(err)
} else if os.IsNotExist(err) { } else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
SLogger.Println(err)
} else { } else {
return return
} }
@ -123,7 +125,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File var f *os.File
@ -191,7 +193,7 @@ func (fp *FileProvider) SessionAll() int {
return a.visit(path, f, err) return a.visit(path, f, err)
}) })
if err != nil { if err != nil {
fmt.Printf("filepath.Walk() returned %v\n", err) SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0 return 0
} }
return a.total return a.total
@ -205,11 +207,11 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File var newf *os.File

View File

@ -31,9 +31,14 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"log"
"net/http" "net/http"
"net/textproto"
"net/url" "net/url"
"os"
"time" "time"
) )
@ -61,6 +66,9 @@ type Provider interface {
var provides = make(map[string]Provider) var provides = make(map[string]Provider)
// SLogger a helpful variable to log information about session
var SLogger = NewSessionLog(os.Stderr)
// Register makes a session provide available by the provided name. // Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
// it panics. // it panics.
@ -75,15 +83,18 @@ func Register(name string, provide Provider) {
} }
type managerConfig struct { type managerConfig struct {
CookieName string `json:"cookieName"` CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"` EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"` Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"` Maxlifetime int64 `json:"maxLifetime"`
Secure bool `json:"secure"` Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"` CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"` ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"` Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"` SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"`
SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"`
EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"`
} }
// Manager contains Provider and its configuration. // Manager contains Provider and its configuration.
@ -118,6 +129,19 @@ func NewManager(provideName, config string) (*Manager, error) {
if cf.Maxlifetime == 0 { if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime cf.Maxlifetime = cf.Gclifetime
} }
if cf.EnableSidInHttpHeader {
if cf.SessionNameInHttpHeader == "" {
panic(errors.New("SessionNameInHttpHeader is empty"))
}
strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader)
if cf.SessionNameInHttpHeader != strMimeHeader {
strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -143,12 +167,24 @@ func NewManager(provideName, config string) (*Manager, error) {
func (manager *Manager) getSid(r *http.Request) (string, error) { func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName) cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 { if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
errs := r.ParseForm() var sid string
if errs != nil { if manager.config.EnableSidInUrlQuery {
return "", errs errs := r.ParseForm()
if errs != nil {
return "", errs
}
sid = r.FormValue(manager.config.CookieName)
}
// if not found in Cookie / param, then read it from request headers
if manager.config.EnableSidInHttpHeader && sid == "" {
sids, isFound := r.Header[manager.config.SessionNameInHttpHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
} }
sid := r.FormValue(manager.config.CookieName)
return sid, nil return sid, nil
} }
@ -192,11 +228,21 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} }
r.AddCookie(cookie) r.AddCookie(cookie)
if manager.config.EnableSidInHttpHeader {
r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
}
return return
} }
// SessionDestroy Destroy session by its id in http request cookie. // SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
if manager.config.EnableSidInHttpHeader {
r.Header.Del(manager.config.SessionNameInHttpHeader)
w.Header().Del(manager.config.SessionNameInHttpHeader)
}
cookie, err := r.Cookie(manager.config.CookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
return return
@ -261,6 +307,12 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
} }
r.AddCookie(cookie) r.AddCookie(cookie)
if manager.config.EnableSidInHttpHeader {
r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
}
return return
} }
@ -296,3 +348,15 @@ func (manager *Manager) isSecure(req *http.Request) bool {
} }
return true return true
} }
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// NewSessionLog set io.Writer to create a Logger for session.
func NewSessionLog(out io.Writer) *Log {
sl := new(Log)
sl.Logger = log.New(out, "[SESSION]", 1e9)
return sl
}

192
session/ssdb/sess_ssdb.go Normal file
View File

@ -0,0 +1,192 @@
package ssdb
import (
"errors"
"net/http"
"strconv"
"strings"
"sync"
"github.com/astaxie/beego/session"
"github.com/ssdb/gossdb/ssdb"
)
var ssdbProvider = &SsdbProvider{}
type SsdbProvider struct {
client *ssdb.Client
host string
port int
maxLifetime int64
}
func (p *SsdbProvider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
if err != nil {
return err
}
return nil
}
func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error {
var e error = nil
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
p.port, e = strconv.Atoi(address[1])
if e != nil {
return e
}
err := p.connectInit()
if err != nil {
return err
}
return nil
}
func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
value, err := p.client.Get(sid)
if err != nil {
return nil, err
}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
func (p *SsdbProvider) SessionExist(sid string) bool {
if p.client == nil {
if err := p.connectInit(); err != nil {
panic(err)
}
}
value, err := p.client.Get(sid)
if err != nil {
panic(err)
}
if value == nil || len(value.(string)) == 0 {
return false
}
return true
}
func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
//conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
value, err := p.client.Get(oldsid)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
_, err = p.client.Del(oldsid)
if err != nil {
return nil, err
}
}
_, e := p.client.Do("setx", sid, value, p.maxLifetime)
if e != nil {
return nil, e
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
func (p *SsdbProvider) SessionDestroy(sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
_, err := p.client.Del(sid)
if err != nil {
return err
}
return nil
}
func (p *SsdbProvider) SessionGC() {
return
}
func (p *SsdbProvider) SessionAll() int {
return 0
}
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxLifetime int64
client *ssdb.Client
}
func (s *SessionStore) Set(key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
return nil
}
func (s *SessionStore) Get(key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
if value, ok := s.values[key]; ok {
return value
}
return nil
}
func (s *SessionStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
return nil
}
func (s *SessionStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
return nil
}
func (s *SessionStore) SessionID() string {
return s.sid
}
func (s *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return
}
s.client.Do("setx", s.sid, string(b), s.maxLifetime)
}
func init() {
session.Register("ssdb", ssdbProvider)
}

View File

@ -27,6 +27,7 @@ import (
"time" "time"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
var errNotStaticRequest = errors.New("request not a static file request") var errNotStaticRequest = errors.New("request not a static file request")
@ -48,14 +49,19 @@ func serverStaticRouter(ctx *context.Context) {
if filePath == "" || fileInfo == nil { if filePath == "" || fileInfo == nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn("Can't find/open the file:", filePath, err) logs.Warn("Can't find/open the file:", filePath, err)
} }
http.NotFound(ctx.ResponseWriter, ctx.Request) http.NotFound(ctx.ResponseWriter, ctx.Request)
return return
} }
if fileInfo.IsDir() { if fileInfo.IsDir() {
//serveFile will list dir requestURL := ctx.Input.URL()
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) if requestURL[len(requestURL)-1] != '/' {
ctx.Redirect(302, requestURL+"/")
} else {
//serveFile will list dir
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
}
return return
} }
@ -67,7 +73,7 @@ func serverStaticRouter(ctx *context.Context) {
b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil { if err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn("Can't compress the file:", filePath, err) logs.Warn("Can't compress the file:", filePath, err)
} }
http.NotFound(ctx.ResponseWriter, ctx.Request) http.NotFound(ctx.ResponseWriter, ctx.Request)
return return

View File

@ -1,160 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package swagger struct definition
package swagger
// SwaggerVersion show the current swagger version
const SwaggerVersion = "1.2"
// ResourceListing list the resource
type ResourceListing struct {
APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2
// BasePath string `json:"basePath"` obsolete in 1.1
APIs []APIRef `json:"apis"`
Info Information `json:"info"`
}
// APIRef description the api path and description
type APIRef struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
}
// Information show the API Information
type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Contact string `json:"contact,omitempty"`
TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
License string `json:"license,omitempty"`
LicenseURL string `json:"licenseUrl,omitempty"`
}
// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
type APIDeclaration struct {
APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"`
BasePath string `json:"basePath"`
ResourcePath string `json:"resourcePath"` // must start with /
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
APIs []API `json:"apis,omitempty"`
Models map[string]Model `json:"models,omitempty"`
}
// API show tha API struct
type API struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
Operations []Operation `json:"operations,omitempty"`
}
// Operation desc the Operation
type Operation struct {
HTTPMethod string `json:"httpMethod"`
Nickname string `json:"nickname"`
Type string `json:"type"` // in 1.1 = DataType
// ResponseClass string `json:"responseClass"` obsolete in 1.2
Summary string `json:"summary,omitempty"`
Notes string `json:"notes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty"`
ResponseMessages []ResponseMessage `json:"responseMessages,omitempty"` // optional
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Authorizations []Authorization `json:"authorizations,omitempty"`
Protocols []Protocol `json:"protocols,omitempty"`
}
// Protocol support which Protocol
type Protocol struct {
}
// ResponseMessage Show the
type ResponseMessage struct {
Code int `json:"code"`
Message string `json:"message"`
ResponseModel string `json:"responseModel"`
}
// Parameter desc the request parameters
type Parameter struct {
ParamType string `json:"paramType"` // path,query,body,header,form
Name string `json:"name"`
Description string `json:"description"`
DataType string `json:"dataType"` // 1.2 needed?
Type string `json:"type"` // integer
Format string `json:"format"` // int64
AllowMultiple bool `json:"allowMultiple"`
Required bool `json:"required"`
Minimum int `json:"minimum"`
Maximum int `json:"maximum"`
}
// ErrorResponse desc response
type ErrorResponse struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
// Model define the data model
type Model struct {
ID string `json:"id"`
Required []string `json:"required,omitempty"`
Properties map[string]ModelProperty `json:"properties"`
}
// ModelProperty define the properties
type ModelProperty struct {
Type string `json:"type"`
Description string `json:"description"`
Items map[string]string `json:"items,omitempty"`
Format string `json:"format"`
}
// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations
type Authorization struct {
LocalOAuth OAuth `json:"local-oauth"`
APIKey APIKey `json:"apiKey"`
}
// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations
type OAuth struct {
Type string `json:"type"` // e.g. oauth2
Scopes []string `json:"scopes"` // e.g. PUBLIC
GrantTypes map[string]GrantType `json:"grantTypes"`
}
// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations
type GrantType struct {
LoginEndpoint Endpoint `json:"loginEndpoint"`
TokenName string `json:"tokenName"` // e.g. access_code
TokenRequestEndpoint Endpoint `json:"tokenRequestEndpoint"`
TokenEndpoint Endpoint `json:"tokenEndpoint"`
}
// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations
type Endpoint struct {
URL string `json:"url"`
ClientIDName string `json:"clientIdName"`
ClientSecretName string `json:"clientSecretName"`
TokenName string `json:"tokenName"`
}
// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations
type APIKey struct {
Type string `json:"type"` // e.g. apiKey
PassAs string `json:"passAs"` // e.g. header
}

155
swagger/swagger.go Normal file
View File

@ -0,0 +1,155 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Swagger™ is a project used to describe and document RESTful APIs.
//
// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools.
// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software.
// Package swagger struct definition
package swagger
// Swagger list the resource
type Swagger struct {
SwaggerVersion string `json:"swagger,omitempty"`
Infos Information `json:"info"`
Host string `json:"host,omitempty"`
BasePath string `json:"basePath,omitempty"`
Schemes []string `json:"schemes,omitempty"`
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Paths map[string]Item `json:"paths"`
Definitions map[string]Schema `json:"definitions,omitempty"`
SecurityDefinitions map[string]Scurity `json:"securityDefinitions,omitempty"`
Security map[string][]string `json:"security,omitempty"`
Tags []Tag `json:"tags,omitempty"`
ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
}
// Information Provides metadata about the API. The metadata can be used by the clients if needed.
type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Version string `json:"version,omitempty"`
TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
Contact Contact `json:"contact,omitempty"`
License License `json:"license,omitempty"`
}
// Contact information for the exposed API.
type Contact struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
EMail string `json:"email,omitempty"`
}
// License information for the exposed API.
type License struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
}
// Item Describes the operations available on a single path.
type Item struct {
Ref string `json:"$ref,omitempty"`
Get *Operation `json:"get,omitempty"`
Put *Operation `json:"put,omitempty"`
Post *Operation `json:"post,omitempty"`
Delete *Operation `json:"delete,omitempty"`
Options *Operation `json:"options,omitempty"`
Head *Operation `json:"head,omitempty"`
Patch *Operation `json:"patch,omitempty"`
}
// Operation Describes a single API operation on a path.
type Operation struct {
Tags []string `json:"tags,omitempty"`
Summary string `json:"summary,omitempty"`
Description string `json:"description,omitempty"`
OperationID string `json:"operationId,omitempty"`
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Schemes []string `json:"schemes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty"`
Responses map[string]Response `json:"responses,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
}
// Parameter Describes a single operation parameter.
type Parameter struct {
In string `json:"in,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
Schema Schema `json:"schema,omitempty"`
Type string `json:"type,omitempty"`
Format string `json:"format,omitempty"`
}
// Schema Object allows the definition of input and output data types.
type Schema struct {
Ref string `json:"$ref,omitempty"`
Title string `json:"title,omitempty"`
Format string `json:"format,omitempty"`
Description string `json:"description,omitempty"`
Required []string `json:"required,omitempty"`
Type string `json:"type,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty"`
}
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
type Propertie struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Default string `json:"default,omitempty"`
Type string `json:"type,omitempty"`
Example string `json:"example,omitempty"`
Required []string `json:"required,omitempty"`
Format string `json:"format,omitempty"`
ReadOnly bool `json:"readOnly,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty"`
}
// Response as they are returned from executing this operation.
type Response struct {
Description string `json:"description,omitempty"`
Schema Schema `json:"schema,omitempty"`
Ref string `json:"$ref,omitempty"`
}
// Scurity Allows the definition of a security scheme that can be used by the operations
type Scurity struct {
Type string `json:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
In string `json:"in,omitempty"` // Valid values are "query" or "header".
Flow string `json:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
AuthorizationURL string `json:"authorizationUrl,omitempty"`
TokenURL string `json:"tokenUrl,omitempty"`
Scopes map[string]string `json:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
}
// Tag Allows adding meta data to a single tag that is used by the Operation Object
type Tag struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
}
// ExternalDocs include Additional external documentation
type ExternalDocs struct {
Description string `json:"description,omitempty"`
URL string `json:"url,omitempty"`
}

View File

@ -26,6 +26,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -36,19 +37,35 @@ var (
templatesLock sync.RWMutex templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build // beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"} beeTemplateExt = []string{"tpl", "html"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templatePreProcessor{}
) )
func executeTemplate(wr io.Writer, name string, data interface{}) error { // ExecuteTemplate applies the template with name to the specified data object,
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
templatesLock.RLock() templatesLock.RLock()
defer templatesLock.RUnlock() defer templatesLock.RUnlock()
} }
if t, ok := beeTemplates[name]; ok { if t, ok := beeTemplates[name]; ok {
err := t.ExecuteTemplate(wr, name, data) if t.Lookup(name) != nil {
if err != nil { err := t.ExecuteTemplate(wr, name, data)
Trace("template Execute err:", err) if err != nil {
logs.Trace("template Execute err:", err)
}
return err
} else {
err := t.Execute(wr, data)
if err != nil {
if err != nil {
logs.Trace("template Execute err:", err)
}
return err
}
} }
return err return nil
} }
panic("can't find templatefile in the path:" + name) panic("can't find templatefile in the path:" + name)
} }
@ -88,6 +105,8 @@ func AddFuncMap(key string, fn interface{}) error {
return nil return nil
} }
type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error)
type templateFile struct { type templateFile struct {
root string root string
files map[string][]string files map[string][]string
@ -156,13 +175,22 @@ func BuildTemplate(dir string, files ...string) error {
fmt.Printf("filepath.Walk() returned %v\n", err) fmt.Printf("filepath.Walk() returned %v\n", err)
return err return err
} }
buildAllFiles := len(files) == 0
for _, v := range self.files { for _, v := range self.files {
for _, file := range v { for _, file := range v {
if len(files) == 0 || utils.InSlice(file, files) { if buildAllFiles || utils.InSlice(file, files) {
templatesLock.Lock() templatesLock.Lock()
t, err := getTemplate(self.root, file, v...) ext := filepath.Ext(file)
var t *template.Template
if len(ext) == 0 {
t, err = getTemplate(self.root, file, v...)
} else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
t, err = fn(self.root, file, beegoTplFuncMap)
} else {
t, err = getTemplate(self.root, file, v...)
}
if err != nil { if err != nil {
Trace("parse template err:", file, err) logs.Trace("parse template err:", file, err)
} else { } else {
beeTemplates[file] = t beeTemplates[file] = t
} }
@ -240,7 +268,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil { if err != nil {
Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 { } else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, subMods1, others...)
} }
@ -261,7 +289,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil { if err != nil {
Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 { } else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, subMods1, others...)
} }
@ -305,3 +333,9 @@ func DelStaticPath(url string) *App {
delete(BConfig.WebConfig.StaticDir, url) delete(BConfig.WebConfig.StaticDir, url)
return BeeApp return BeeApp
} }
func AddTemplateEngine(extension string, fn templatePreProcessor) *App {
AddTemplateExt(extension)
beeTemplateEngines[extension] = fn
return BeeApp
}

View File

@ -421,18 +421,18 @@ func RenderForm(obj interface{}) template.HTML {
fieldT := objT.Field(i) fieldT := objT.Field(i)
label, name, fType, id, class, ignored := parseFormTag(fieldT) label, name, fType, id, class, ignored, required := parseFormTag(fieldT)
if ignored { if ignored {
continue continue
} }
raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class)) raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required))
} }
return template.HTML(strings.Join(raw, "</br>")) return template.HTML(strings.Join(raw, "</br>"))
} }
// renderFormField returns a string containing HTML of a single form field. // renderFormField returns a string containing HTML of a single form field.
func renderFormField(label, name, fType string, value interface{}, id string, class string) string { func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string {
if id != "" { if id != "" {
id = " id=\"" + id + "\"" id = " id=\"" + id + "\""
} }
@ -441,11 +441,16 @@ func renderFormField(label, name, fType string, value interface{}, id string, cl
class = " class=\"" + class + "\"" class = " class=\"" + class + "\""
} }
if isValidForInput(fType) { requiredString := ""
return fmt.Sprintf(`%v<input%v%v name="%v" type="%v" value="%v">`, label, id, class, name, fType, value) if required {
requiredString = " required"
} }
return fmt.Sprintf(`%v<%v%v%v name="%v">%v</%v>`, label, fType, id, class, name, value, fType) if isValidForInput(fType) {
return fmt.Sprintf(`%v<input%v%v name="%v" type="%v" value="%v"%v>`, label, id, class, name, fType, value, requiredString)
}
return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v</%v>`, label, fType, id, class, name, requiredString, value, fType)
} }
// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. // isValidForInput checks if fType is a valid value for the `type` property of an HTML input element.
@ -461,7 +466,7 @@ func isValidForInput(fType string) bool {
// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. // parseFormTag takes the stuct-tag of a StructField and parses the `form` value.
// returned are the form label, name-property, type and wether the field should be ignored. // returned are the form label, name-property, type and wether the field should be ignored.
func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool) { func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) {
tags := strings.Split(fieldT.Tag.Get("form"), ",") tags := strings.Split(fieldT.Tag.Get("form"), ",")
label = fieldT.Name + ": " label = fieldT.Name + ": "
name = fieldT.Name name = fieldT.Name
@ -470,6 +475,12 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
id = fieldT.Tag.Get("id") id = fieldT.Tag.Get("id")
class = fieldT.Tag.Get("class") class = fieldT.Tag.Get("class")
required = false
required_field := fieldT.Tag.Get("required")
if required_field != "-" && required_field != "" {
required, _ = strconv.ParseBool(required_field)
}
switch len(tags) { switch len(tags) {
case 1: case 1:
if tags[0] == "-" { if tags[0] == "-" {
@ -496,6 +507,7 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
label = tags[2] label = tags[2]
} }
} }
return return
} }

View File

@ -195,54 +195,78 @@ func TestRenderForm(t *testing.T) {
} }
func TestRenderFormField(t *testing.T) { func TestRenderFormField(t *testing.T) {
html := renderFormField("Label: ", "Name", "text", "Value", "", "") html := renderFormField("Label: ", "Name", "text", "Value", "", "", false)
if html != `Label: <input name="Name" type="text" value="Value">` { if html != `Label: <input name="Name" type="text" value="Value">` {
t.Errorf("Wrong html output for input[type=text]: %v ", html) t.Errorf("Wrong html output for input[type=text]: %v ", html)
} }
html = renderFormField("Label: ", "Name", "textarea", "Value", "", "") html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false)
if html != `Label: <textarea name="Name">Value</textarea>` { if html != `Label: <textarea name="Name">Value</textarea>` {
t.Errorf("Wrong html output for textarea: %v ", html) t.Errorf("Wrong html output for textarea: %v ", html)
} }
html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true)
if html != `Label: <textarea name="Name" required>Value</textarea>` {
t.Errorf("Wrong html output for textarea: %v ", html)
}
} }
func TestParseFormTag(t *testing.T) { func TestParseFormTag(t *testing.T) {
// create struct to contain field with different types of struct-tag `form` // create struct to contain field with different types of struct-tag `form`
type user struct { type user struct {
All int `form:"name,text,年龄:"` All int `form:"name,text,年龄:"`
NoName int `form:",hidden,年龄:"` NoName int `form:",hidden,年龄:"`
OnlyLabel int `form:",,年龄:"` OnlyLabel int `form:",,年龄:"`
OnlyName int `form:"name" id:"name" class:"form-name"` OnlyName int `form:"name" id:"name" class:"form-name"`
Ignored int `form:"-"` Ignored int `form:"-"`
Required int `form:"name" required:"true"`
IgnoreRequired int `form:"name"`
NotRequired int `form:"name" required:"false"`
} }
objT := reflect.TypeOf(&user{}).Elem() objT := reflect.TypeOf(&user{}).Elem()
label, name, fType, id, class, ignored := parseFormTag(objT.Field(0)) label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0))
if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) { if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag with name, label and type was not correctly parsed.") t.Errorf("Form Tag with name, label and type was not correctly parsed.")
} }
label, name, fType, id, class, ignored = parseFormTag(objT.Field(1)) label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1))
if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) { if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) {
t.Errorf("Form Tag with label and type but without name was not correctly parsed.") t.Errorf("Form Tag with label and type but without name was not correctly parsed.")
} }
label, name, fType, id, class, ignored = parseFormTag(objT.Field(2)) label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2))
if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) { if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag containing only label was not correctly parsed.") t.Errorf("Form Tag containing only label was not correctly parsed.")
} }
label, name, fType, id, class, ignored = parseFormTag(objT.Field(3)) label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3))
if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false && if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false &&
id == "name" && class == "form-name") { id == "name" && class == "form-name") {
t.Errorf("Form Tag containing only name was not correctly parsed.") t.Errorf("Form Tag containing only name was not correctly parsed.")
} }
label, name, fType, id, class, ignored = parseFormTag(objT.Field(4)) label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4))
if ignored == false { if ignored == false {
t.Errorf("Form Tag that should be ignored was not correctly parsed.") t.Errorf("Form Tag that should be ignored was not correctly parsed.")
} }
label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5))
if !(name == "name" && required == true) {
t.Errorf("Form Tag containing only name and required was not correctly parsed.")
}
label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6))
if !(name == "name" && required == false) {
t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.")
}
label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7))
if !(name == "name" && required == false) {
t.Errorf("Form Tag containing only name and not required was not correctly parsed.")
}
} }
func TestMapGet(t *testing.T) { func TestMapGet(t *testing.T) {

View File

@ -389,6 +389,10 @@ func dayMatches(s *Schedule, t time.Time) bool {
// StartTask start all tasks // StartTask start all tasks
func StartTask() { func StartTask() {
if isstart {
//If already started no need to start another goroutine.
return
}
isstart = true isstart = true
go run() go run()
} }
@ -432,8 +436,11 @@ func run() {
// StopTask stop all tasks // StopTask stop all tasks
func StopTask() { func StopTask() {
isstart = false if isstart {
stop <- true isstart = false
stop <- true
}
} }
// AddTask add task with name // AddTask add task with name

View File

@ -69,6 +69,7 @@ import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -139,7 +140,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
if err := c.store.Put(key, chars, c.Expiration); err != nil { if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500) ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error") ctx.WriteString("captcha reload error")
beego.Error("Reload Create Captcha Error:", err) logs.Error("Reload Create Captcha Error:", err)
return return
} }
} else { } else {
@ -154,7 +155,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
img := NewImage(chars, c.StdWidth, c.StdHeight) img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
beego.Error("Write Captcha Image Error:", err) logs.Error("Write Captcha Image Error:", err)
} }
} }
@ -162,7 +163,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
func (c *Captcha) CreateCaptchaHTML() template.HTML { func (c *Captcha) CreateCaptchaHTML() template.HTML {
value, err := c.CreateCaptcha() value, err := c.CreateCaptcha()
if err != nil { if err != nil {
beego.Error("Create Captcha Error:", err) logs.Error("Create Captcha Error:", err)
return "" return ""
} }

View File

@ -18,7 +18,7 @@ import (
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
) )
// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data["paginator"]. // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator").
func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) {
paginator = NewPaginator(context.Request, per, nums) paginator = NewPaginator(context.Request, per, nums)
context.Input.SetData("paginator", &paginator) context.Input.SetData("paginator", &paginator)

View File

@ -20,28 +20,24 @@ import (
"time" "time"
) )
var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`)
// RandomCreateBytes generate random []byte by specify chars. // RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte { func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" if len(alphabets) == 0 {
alphabets = alphaNum
}
var bytes = make([]byte, n) var bytes = make([]byte, n)
var randby bool var randBy bool
if num, err := rand.Read(bytes); num != n || err != nil { if num, err := rand.Read(bytes); num != n || err != nil {
r.Seed(time.Now().UnixNano()) r.Seed(time.Now().UnixNano())
randby = true randBy = true
} }
for i, b := range bytes { for i, b := range bytes {
if len(alphabets) == 0 { if randBy {
if randby { bytes[i] = alphabets[r.Intn(len(alphabets))]
bytes[i] = alphanum[r.Intn(len(alphanum))]
} else {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
} else { } else {
if randby { bytes[i] = alphabets[b%byte(len(alphabets))]
bytes[i] = alphabets[r.Intn(len(alphabets))]
} else {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
} }
} }
return bytes return bytes

View File

@ -1,4 +1,4 @@
// Copyright 2014 beego Author. All Rights Reserved. // Copyright 2016 beego Author. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,28 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package beego package utils
import ( import "testing"
"github.com/astaxie/beego/context"
)
// GlobalDocAPI store the swagger api documents func TestRand_01(t *testing.T) {
var GlobalDocAPI = make(map[string]interface{}) bs0 := RandomCreateBytes(16)
bs1 := RandomCreateBytes(16)
func serverDocs(ctx *context.Context) { t.Log(string(bs0), string(bs1))
var obj interface{} if string(bs0) == string(bs1) {
if splat := ctx.Input.Param(":splat"); splat == "" { t.FailNow()
obj = GlobalDocAPI["Root"]
} else {
if v, ok := GlobalDocAPI[splat]; ok {
obj = v
}
} }
if obj != nil {
ctx.Output.Header("Access-Control-Allow-Origin", "*") bs0 = RandomCreateBytes(4, []byte(`a`)...)
ctx.Output.JSON(obj, false, false)
return if string(bs0) != "aaaa" {
t.FailNow()
} }
ctx.Output.SetStatus(404)
} }

View File

@ -73,6 +73,10 @@ func (e *Error) String() string {
return e.Message return e.Message
} }
// Implement Error interface.
// Return e.String()
func (e *Error) Error() string { return e.String() }
// Result is returned from every validation method. // Result is returned from every validation method.
// It provides an indication of success, and a pointer to the Error (if any). // It provides an indication of success, and a pointer to the Error (if any).
type Result struct { type Result struct {