1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 06:01:29 +00:00

Merge branch 'astaxie/develop' into develop

# Conflicts:
#	parser.go
This commit is contained in:
ysqi 2016-05-05 19:30:31 +08:00
commit 3c8ed9adfc
73 changed files with 2415 additions and 993 deletions

17
.github/ISSUE_TEMPLATE vendored Normal file
View File

@ -0,0 +1,17 @@
Please answer these questions before submitting your issue. Thanks!
1. What version of Go and beego are you using (`bee version`)?
2. What operating system and processor architecture are you using (`go env`)?
3. What did you do?
If possible, provide a recipe for reproducing the error.
A complete runnable program is good.
4. What did you expect to see?
5. What did you see instead?

View File

@ -1,8 +1,7 @@
language: go language: go
go: go:
- tip - 1.6
- 1.6.0
- 1.5.3 - 1.5.3
- 1.4.3 - 1.4.3
services: services:
@ -31,21 +30,20 @@ install:
- go get github.com/belogik/goes - go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis - go get github.com/siddontang/ledisdb/ledis
- go get golang.org/x/tools/cmd/vet
- go get github.com/golang/lint/golint
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
before_script: before_script:
- psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi"
- mkdir -p res/var - mkdir -p res/var
- ./ssdb/ssdb-server ./ssdb/ssdb.conf -d - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d
after_script: after_script:
-killall -w ssdb-server -killall -w ssdb-server
- rm -rf ./res/var/* - rm -rf ./res/var/*
script: script:
- go vet -x ./...
- $HOME/gopath/bin/golint ./...
- go test -v ./... - go test -v ./...
notifications: addons:
webhooks: https://hooks.pubu.im/services/z7m9bvybl3rgtg9 postgresql: "9.4"

View File

@ -30,7 +30,7 @@ func main(){
``` ```
######Congratulations! ######Congratulations!
You just built your first beego app. You just built your first beego app.
Open your browser and visit `http://localhost:8000`. Open your browser and visit `http://localhost:8080`.
Please see [Documentation](http://beego.me/docs) for more. Please see [Documentation](http://beego.me/docs) for more.
## Features ## Features

View File

@ -23,7 +23,10 @@ import (
"text/template" "text/template"
"time" "time"
"reflect"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -90,57 +93,9 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
switch command { switch command {
case "conf": case "conf":
m := make(map[string]interface{}) m := make(map[string]interface{})
list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath m["AppConfigPath"] = appConfigPath
m["AppConfigProvider"] = appConfigProvider m["AppConfigProvider"] = appConfigProvider
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl)) tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
@ -196,7 +151,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
BeforeExec: "Before Exec", BeforeExec: "Before Exec",
AfterExec: "After Exec", AfterExec: "After Exec",
FinishRouter: "Finish Router"} { FinishRouter: "Finish Router"} {
if bf, ok := BeeApp.Handlers.filters[k]; ok { if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 {
filterType = fr filterType = fr
filterTypes = append(filterTypes, filterType) filterTypes = append(filterTypes, filterType)
resultList := new([][]string) resultList := new([][]string)
@ -223,6 +178,28 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
} }
} }
func list(root string, p interface{}, m map[string]interface{}) {
pt := reflect.TypeOf(p)
pv := reflect.ValueOf(p)
if pt.Kind() == reflect.Ptr {
pt = pt.Elem()
pv = pv.Elem()
}
for i := 0; i < pv.NumField(); i++ {
var key string
if root == "" {
key = pt.Field(i).Name
} else {
key = root + "." + pt.Field(i).Name
}
if pv.Field(i).Kind() == reflect.Struct {
list(key, pv.Field(i).Interface(), m)
} else {
m[key] = pv.Field(i).Interface()
}
}
}
func printTree(resultList *[][]string, t *Tree) { func printTree(resultList *[][]string, t *Tree) {
for _, tr := range t.fixrouters { for _, tr := range t.fixrouters {
printTree(resultList, tr) printTree(resultList, tr)
@ -410,7 +387,7 @@ func (admin *adminApp) Run() {
for p, f := range admin.routers { for p, f := range admin.routers {
http.Handle(p, f) http.Handle(p, f)
} }
BeeLogger.Info("Admin server Running on %s", addr) logs.Info("Admin server Running on %s", addr)
var err error var err error
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
@ -419,6 +396,6 @@ func (admin *adminApp) Run() {
err = http.ListenAndServe(addr, nil) err = http.ListenAndServe(addr, nil)
} }
if err != nil { if err != nil {
BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
} }
} }

72
admin_test.go Normal file
View File

@ -0,0 +1,72 @@
package beego
import (
"testing"
"fmt"
)
func TestList_01(t *testing.T) {
m := make(map[string]interface{})
list("BConfig", BConfig, m)
t.Log(m)
om := oldMap()
for k, v := range om {
if fmt.Sprint(m[k])!= fmt.Sprint(v) {
t.Log(k, "old-key",v,"new-key", m[k])
t.FailNow()
}
}
}
func oldMap() map[string]interface{} {
m := make(map[string]interface{})
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
return m
}

30
app.go
View File

@ -24,6 +24,7 @@ import (
"time" "time"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -68,9 +69,9 @@ func (app *App) Run() {
if BConfig.Listen.EnableFcgi { if BConfig.Listen.EnableFcgi {
if BConfig.Listen.EnableStdIo { if BConfig.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O") logs.Info("Use FCGI via standard I/O")
} else { } else {
BeeLogger.Critical("Cannot use FCGI via standard I/O", err) logs.Critical("Cannot use FCGI via standard I/O", err)
} }
return return
} }
@ -84,10 +85,10 @@ func (app *App) Run() {
l, err = net.Listen("tcp", addr) l, err = net.Listen("tcp", addr)
} }
if err != nil { if err != nil {
BeeLogger.Critical("Listen: ", err) logs.Critical("Listen: ", err)
} }
if err = fcgi.Serve(l, app.Handlers); err != nil { if err = fcgi.Serve(l, app.Handlers); err != nil {
BeeLogger.Critical("fcgi.Serve: ", err) logs.Critical("fcgi.Serve: ", err)
} }
return return
} }
@ -95,6 +96,7 @@ func (app *App) Run() {
app.Server.Handler = app.Handlers app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP")
// run graceful mode // run graceful mode
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
@ -111,7 +113,7 @@ func (app *App) Run() {
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -126,7 +128,7 @@ func (app *App) Run() {
server.Network = "tcp4" server.Network = "tcp4"
} }
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -137,16 +139,18 @@ func (app *App) Run() {
} }
// run normal mode // run normal mode
app.Server.Addr = addr
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
} else if BConfig.Listen.EnableHTTP {
BeeLogger.Info("Start https server error, confict with http.Please reset https port")
return
} }
BeeLogger.Info("https server Running on %s", app.Server.Addr) logs.Info("https server Running on %s", app.Server.Addr)
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err) logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
@ -155,24 +159,24 @@ func (app *App) Run() {
if BConfig.Listen.EnableHTTP { if BConfig.Listen.EnableHTTP {
go func() { go func() {
app.Server.Addr = addr app.Server.Addr = addr
BeeLogger.Info("http server Running on %s", app.Server.Addr) logs.Info("http server Running on %s", app.Server.Addr)
if BConfig.Listen.ListenTCP4 { if BConfig.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr) ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil { if err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return
} }
if err = app.Server.Serve(ln); err != nil { if err = app.Server.Serve(ln); err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return
} }
} else { } else {
if err := app.Server.ListenAndServe(); err != nil { if err := app.Server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err) logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }

View File

@ -51,6 +51,7 @@ func AddAPPStartHook(hf hookfunc) {
// beego.Run(":8089") // beego.Run(":8089")
// beego.Run("127.0.0.1:8089") // beego.Run("127.0.0.1:8089")
func Run(params ...string) { func Run(params ...string) {
initBeforeHTTPRun() initBeforeHTTPRun()
if len(params) > 0 && params[0] != "" { if len(params) > 0 && params[0] != "" {
@ -71,9 +72,9 @@ func initBeforeHTTPRun() {
AddAPPStartHook(registerMime) AddAPPStartHook(registerMime)
AddAPPStartHook(registerDefaultErrorHandler) AddAPPStartHook(registerDefaultErrorHandler)
AddAPPStartHook(registerSession) AddAPPStartHook(registerSession)
AddAPPStartHook(registerDocs)
AddAPPStartHook(registerTemplate) AddAPPStartHook(registerTemplate)
AddAPPStartHook(registerAdmin) AddAPPStartHook(registerAdmin)
AddAPPStartHook(registerGzip)
for _, hk := range hooks { for _, hk := range hooks {
if err := hk(); err != nil { if err := hk(); err != nil {
@ -84,8 +85,11 @@ func initBeforeHTTPRun() {
// TestBeegoInit is for test package init // TestBeegoInit is for test package init
func TestBeegoInit(ap string) { func TestBeegoInit(ap string) {
os.Setenv("BEEGO_RUNMODE", "test")
appConfigPath = filepath.Join(ap, "conf", "app.conf") appConfigPath = filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap) os.Chdir(ap)
if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil {
panic(err)
}
BConfig.RunMode = "test"
initBeforeHTTPRun() initBeforeHTTPRun()
} }

View File

@ -18,9 +18,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/garyburd/redigo/redis"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"github.com/garyburd/redigo/redis"
) )
func TestRedisCache(t *testing.T) { func TestRedisCache(t *testing.T) {

View File

@ -1,10 +1,11 @@
package ssdb package ssdb
import ( import (
"github.com/astaxie/beego/cache"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/astaxie/beego/cache"
) )
func TestSsdbcacheCache(t *testing.T) { func TestSsdbcacheCache(t *testing.T) {

155
config.go
View File

@ -18,9 +18,11 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -89,6 +91,9 @@ type SessionConfig struct {
SessionCookieLifeTime int SessionCookieLifeTime int
SessionAutoSetCookie bool SessionAutoSetCookie bool
SessionDomain string SessionDomain string
EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers
SessionNameInHttpHeader string
EnableSidInUrlQuery bool // enable get the sessionId from Url Query params
} }
// LogConfig holds Log related config // LogConfig holds Log related config
@ -115,11 +120,30 @@ var (
) )
func init() { func init() {
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) BConfig = newBConfig()
var err error
if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil {
panic(err)
}
workPath, err := os.Getwd()
if err != nil {
panic(err)
}
appConfigPath = filepath.Join(workPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
}
}
if err = parseConfig(appConfigPath); err != nil {
panic(err)
}
}
os.Chdir(AppPath) func newBConfig() *Config {
return &Config{
BConfig = &Config{
AppName: "beego", AppName: "beego",
RunMode: DEV, RunMode: DEV,
RouterCaseSensitive: true, RouterCaseSensitive: true,
@ -170,6 +194,9 @@ func init() {
SessionCookieLifeTime: 0, //set cookie default is the browser life SessionCookieLifeTime: 0, //set cookie default is the browser life
SessionAutoSetCookie: true, SessionAutoSetCookie: true,
SessionDomain: "", SessionDomain: "",
EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers
SessionNameInHttpHeader: "Beegosessionid",
EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params
}, },
}, },
Log: LogConfig{ Log: LogConfig{
@ -178,16 +205,6 @@ func init() {
Outputs: map[string]string{"console": ""}, Outputs: map[string]string{"console": ""},
}, },
} }
appConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
}
if err := parseConfig(appConfigPath); err != nil {
panic(err)
}
} }
// now only support ini, next will support json. // now only support ini, next will support json.
@ -196,63 +213,23 @@ func parseConfig(appConfigPath string) (err error) {
if err != nil { if err != nil {
return err return err
} }
return assignConfig(AppConfig)
}
func assignConfig(ac config.Configer) error {
// set the run mode first // set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode BConfig.RunMode = envRunMode
} else if runMode := AppConfig.String("RunMode"); runMode != "" { } else if runMode := ac.String("RunMode"); runMode != "" {
BConfig.RunMode = runMode BConfig.RunMode = runMode
} }
BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName) for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic) assignSingleConfig(i, ac)
BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
BConfig.Log.AccessLogs = AppConfig.DefaultBool("LogAccessLogs", BConfig.Log.AccessLogs)
BConfig.Log.FileLineNum = AppConfig.DefaultBool("LogFileLineNum", BConfig.Log.FileLineNum)
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range BConfig.WebConfig.StaticDir {
delete(BConfig.WebConfig.StaticDir, k)
} }
if sd := ac.String("StaticDir"); sd != "" {
BConfig.WebConfig.StaticDir = map[string]string{}
sds := strings.Fields(sd) sds := strings.Fields(sd)
for _, v := range sds { for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
@ -263,7 +240,7 @@ func parseConfig(appConfigPath string) (err error) {
} }
} }
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",") extensions := strings.Split(sgz, ",")
fileExts := []string{} fileExts := []string{}
for _, ext := range extensions { for _, ext := range extensions {
@ -281,7 +258,7 @@ func parseConfig(appConfigPath string) (err error) {
} }
} }
if lo := AppConfig.String("LogOutputs"); lo != "" { if lo := ac.String("LogOutputs"); lo != "" {
los := strings.Split(lo, ";") los := strings.Split(lo, ";")
for _, v := range los { for _, v := range los {
if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 {
@ -293,18 +270,50 @@ func parseConfig(appConfigPath string) (err error) {
} }
//init log //init log
BeeLogger.Reset() logs.Reset()
for adaptor, config := range BConfig.Log.Outputs { for adaptor, config := range BConfig.Log.Outputs {
err = BeeLogger.SetLogger(adaptor, config) err := logs.SetLogger(adaptor, config)
if err != nil { if err != nil {
fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err) fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error()))
} }
} }
SetLogFuncCall(BConfig.Log.FileLineNum) logs.SetLogFuncCall(BConfig.Log.FileLineNum)
return nil return nil
} }
func assignSingleConfig(p interface{}, ac config.Configer) {
pt := reflect.TypeOf(p)
if pt.Kind() != reflect.Ptr {
return
}
pt = pt.Elem()
if pt.Kind() != reflect.Struct {
return
}
pv := reflect.ValueOf(p).Elem()
for i := 0; i < pt.NumField(); i++ {
pf := pv.Field(i)
if !pf.CanSet() {
continue
}
name := pt.Field(i).Name
switch pf.Kind() {
case reflect.String:
pf.SetString(ac.DefaultString(name, pf.String()))
case reflect.Int, reflect.Int64:
pf.SetInt(int64(ac.DefaultInt64(name, pf.Int())))
case reflect.Bool:
pf.SetBool(ac.DefaultBool(name, pf.Bool()))
case reflect.Struct:
default:
//do nothing here
}
}
}
// LoadAppConfig allow developer to apply a config file // LoadAppConfig allow developer to apply a config file
func LoadAppConfig(adapterName, configPath string) error { func LoadAppConfig(adapterName, configPath string) error {
absConfigPath, err := filepath.Abs(configPath) absConfigPath, err := filepath.Abs(configPath)
@ -316,10 +325,6 @@ func LoadAppConfig(adapterName, configPath string) error {
return fmt.Errorf("the target config file: %s don't exist", configPath) return fmt.Errorf("the target config file: %s don't exist", configPath)
} }
if absConfigPath == appConfigPath {
return nil
}
appConfigPath = absConfigPath appConfigPath = absConfigPath
appConfigProvider = adapterName appConfigProvider = adapterName
@ -353,7 +358,7 @@ func (b *beegoAppConfig) String(key string) string {
} }
func (b *beegoAppConfig) Strings(key string) []string { func (b *beegoAppConfig) Strings(key string) []string {
if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" { if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 {
return v return v
} }
return b.innerConfig.Strings(key) return b.innerConfig.Strings(key)

View File

@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package config is used to parse config // Package config is used to parse config.
// Usage: // Usage:
// import( // import "github.com/astaxie/beego/config"
// "github.com/astaxie/beego/config" //Examples.
// )
// //
// cnf, err := config.NewConfig("ini", "config.conf") // cnf, err := config.NewConfig("ini", "config.conf")
// //
@ -38,12 +37,12 @@
// cnf.DIY(key string) (interface{}, error) // cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error) // cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error // cnf.SaveConfigFile(filename string) error
// //More docs http://beego.me/docs/module/config.md
// more docs http://beego.me/docs/module/config.md
package config package config
import ( import (
"fmt" "fmt"
"os"
) )
// Configer defines how to get and set value from configuration raw data. // Configer defines how to get and set value from configuration raw data.
@ -107,6 +106,69 @@ func NewConfigData(adapterName string, data []byte) (Configer, error) {
return adapter.ParseData(data) return adapter.ParseData(data)
} }
// ExpandValueEnvForMap convert all string value with environment variable.
func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} {
for k, v := range m {
switch value := v.(type) {
case string:
m[k] = ExpandValueEnv(value)
case map[string]interface{}:
m[k] = ExpandValueEnvForMap(value)
case map[string]string:
for k2, v2 := range value {
value[k2] = ExpandValueEnv(v2)
}
m[k] = value
}
}
return m
}
// ExpandValueEnv returns value of convert with environment variable.
//
// Return environment variable if value start with "${" and end with "}".
// Return default value if environment variable is empty or not exist.
//
// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue".
// Examples:
// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable.
// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/".
// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie".
func ExpandValueEnv(value string) (realValue string) {
realValue = value
vLen := len(value)
// 3 = ${}
if vLen < 3 {
return
}
// Need start with "${" and end with "}", then return.
if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' {
return
}
key := ""
defalutV := ""
// value start with "${"
for i := 2; i < vLen; i++ {
if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
key = value[2:i]
defalutV = value[i+2 : vLen-1] // other string is default value.
break
} else if value[i] == '}' {
key = value[2:i]
break
}
}
realValue = os.Getenv(key)
if realValue == "" {
realValue = defalutV
}
return
}
// ParseBool returns the boolean value represented by the string. // ParseBool returns the boolean value represented by the string.
// //
// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, // It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On,

55
config/config_test.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"os"
"testing"
)
func TestExpandValueEnv(t *testing.T) {
testCases := []struct {
item string
want string
}{
{"", ""},
{"$", "$"},
{"{", "{"},
{"{}", "{}"},
{"${}", ""},
{"${|}", ""},
{"${}", ""},
{"${{}}", ""},
{"${{||}}", "}"},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}", ""},
{"${pwd||}}", "}"},
{"${pwd||{{||}}}", "{{||}}"},
{"${GOPATH}", os.Getenv("GOPATH")},
{"${GOPATH||}", os.Getenv("GOPATH")},
{"${GOPATH||root}", os.Getenv("GOPATH")},
{"${GOPATH_NOT||root}", "root"},
{"${GOPATH_NOT||||root}", "||root"},
}
for _, c := range testCases {
if got := ExpandValueEnv(c.item); got != c.want {
t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got)
}
}
}

View File

@ -38,7 +38,7 @@ func (c *fakeConfigContainer) String(key string) string {
} }
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
v := c.getData(key) v := c.String(key)
if v == "" { if v == "" {
return defaultval return defaultval
} }
@ -46,7 +46,7 @@ func (c *fakeConfigContainer) DefaultString(key string, defaultval string) strin
} }
func (c *fakeConfigContainer) Strings(key string) []string { func (c *fakeConfigContainer) Strings(key string) []string {
v := c.getData(key) v := c.String(key)
if v == "" { if v == "" {
return nil return nil
} }

View File

@ -82,6 +82,10 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
if err == io.EOF { if err == io.EOF {
break break
} }
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
if bytes.Equal(line, bEmpty) { if bytes.Equal(line, bEmpty) {
continue continue
} }
@ -162,7 +166,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
val = bytes.Trim(val, `"`) val = bytes.Trim(val, `"`)
} }
cfg.data[section][key] = string(val) cfg.data[section][key] = ExpandValueEnv(string(val))
if comment.Len() > 0 { if comment.Len() > 0 {
cfg.keyComment[section+"."+key] = comment.String() cfg.keyComment[section+"."+key] = comment.String()
comment.Reset() comment.Reset()
@ -296,7 +300,9 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
return nil, errors.New("not exist setction") return nil, errors.New("not exist setction")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file.
//
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation.
func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)

View File

@ -42,11 +42,14 @@ needlogin = ON
enableSession = Y enableSession = Y
enableCookie = N enableCookie = N
flag = 1 flag = 1
path1 = ${GOPATH}
path2 = ${GOPATH||/home/go}
[demo] [demo]
key1="asta" key1="asta"
key2 = "xie" key2 = "xie"
CaseInsensitive = true CaseInsensitive = true
peers = one;two;three peers = one;two;three
password = ${GOPATH}
` `
keyValue = map[string]interface{}{ keyValue = map[string]interface{}{
@ -64,10 +67,13 @@ peers = one;two;three
"enableSession": true, "enableSession": true,
"enableCookie": false, "enableCookie": false,
"flag": true, "flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"demo::key1": "asta", "demo::key1": "asta",
"demo::key2": "xie", "demo::key2": "xie",
"demo::CaseInsensitive": true, "demo::CaseInsensitive": true,
"demo::peers": []string{"one", "two", "three"}, "demo::peers": []string{"one", "two", "three"},
"demo::password": os.Getenv("GOPATH"),
"null": "", "null": "",
"demo2::key1": "", "demo2::key1": "",
"error": "", "error": "",

View File

@ -57,6 +57,9 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
} }
x.data["rootArray"] = wrappingArray x.data["rootArray"] = wrappingArray
} }
x.data = ExpandValueEnvForMap(x.data)
return x, nil return x, nil
} }

View File

@ -86,16 +86,19 @@ func TestJson(t *testing.T) {
"enableSession": "Y", "enableSession": "Y",
"enableCookie": "N", "enableCookie": "N",
"flag": 1, "flag": 1,
"path1": "${GOPATH}",
"path2": "${GOPATH||/home/go}",
"database": { "database": {
"host": "host", "host": "host",
"port": "port", "port": "port",
"database": "database", "database": "database",
"username": "username", "username": "username",
"password": "password", "password": "${GOPATH}",
"conns":{ "conns":{
"maxconnection":12, "maxconnection":12,
"autoconnect":true, "autoconnect":true,
"connectioninfo":"info" "connectioninfo":"info",
"root": "${GOPATH}"
} }
} }
}` }`
@ -115,13 +118,16 @@ func TestJson(t *testing.T) {
"enableSession": true, "enableSession": true,
"enableCookie": false, "enableCookie": false,
"flag": true, "flag": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"database::host": "host", "database::host": "host",
"database::port": "port", "database::port": "port",
"database::database": "database", "database::database": "database",
"database::password": "password", "database::password": os.Getenv("GOPATH"),
"database::conns::maxconnection": 12, "database::conns::maxconnection": 12,
"database::conns::autoconnect": true, "database::conns::autoconnect": true,
"database::conns::connectioninfo": "info", "database::conns::connectioninfo": "info",
"database::conns::root": os.Getenv("GOPATH"),
"unknown": "", "unknown": "",
} }
) )

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package xml for config provider // Package xml for config provider.
// //
// depend on github.com/beego/x2j // depend on github.com/beego/x2j.
// //
// go install github.com/beego/x2j // go install github.com/beego/x2j.
// //
// Usage: // Usage:
// import( // import(
@ -26,7 +26,7 @@
// //
// cnf, err := config.NewConfig("xml", "config.xml") // cnf, err := config.NewConfig("xml", "config.xml")
// //
// more docs http://beego.me/docs/module/config.md //More docs http://beego.me/docs/module/config.md
package xml package xml
import ( import (
@ -69,7 +69,7 @@ func (xc *Config) Parse(filename string) (config.Configer, error) {
return nil, err return nil, err
} }
x.data = d["config"].(map[string]interface{}) x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil return x, nil
} }
@ -92,7 +92,7 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key]; ok { if v := c.data[key]; v != nil {
return config.ParseBool(v) return config.ParseBool(v)
} }
return false, fmt.Errorf("not exist key: %q", key) return false, fmt.Errorf("not exist key: %q", key)

View File

@ -15,14 +15,18 @@
package xml package xml
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
) )
//xml parse should incluce in <config></config> tags func TestXML(t *testing.T) {
var xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
var (
//xml parse should incluce in <config></config> tags
xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
<config> <config>
<appname>beeapi</appname> <appname>beeapi</appname>
<httpport>8080</httpport> <httpport>8080</httpport>
@ -31,10 +35,25 @@ var xmlcontext = `<?xml version="1.0" encoding="UTF-8"?>
<runmode>dev</runmode> <runmode>dev</runmode>
<autorender>false</autorender> <autorender>false</autorender>
<copyrequestbody>true</copyrequestbody> <copyrequestbody>true</copyrequestbody>
<path1>${GOPATH}</path1>
<path2>${GOPATH||/home/go}</path2>
</config> </config>
` `
keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
func TestXML(t *testing.T) {
f, err := os.Create("testxml.conf") f, err := os.Create("testxml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -50,39 +69,42 @@ func TestXML(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if xmlconf.String("appname") != "beeapi" {
t.Fatal("appname not equal to beeapi") for k, v := range keyValue {
var (
value interface{}
err error
)
switch v.(type) {
case int:
value, err = xmlconf.Int(k)
case int64:
value, err = xmlconf.Int64(k)
case float64:
value, err = xmlconf.Float(k)
case bool:
value, err = xmlconf.Bool(k)
case []string:
value = xmlconf.Strings(k)
case string:
value = xmlconf.String(k)
default:
value, err = xmlconf.DIY(k)
} }
if port, err := xmlconf.Int("httpport"); err != nil || port != 8080 { if err != nil {
t.Error(port) t.Errorf("get key %q value fatal,%v err %s", k, v, err)
t.Fatal(err) } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
} }
if port, err := xmlconf.Int64("mysqlport"); err != nil || port != 3600 {
t.Error(port)
t.Fatal(err)
}
if pi, err := xmlconf.Float("PI"); err != nil || pi != 3.1415976 {
t.Error(pi)
t.Fatal(err)
}
if xmlconf.String("runmode") != "dev" {
t.Fatal("runmode not equal to dev")
}
if v, err := xmlconf.Bool("autorender"); err != nil || v != false {
t.Error(v)
t.Fatal(err)
}
if v, err := xmlconf.Bool("copyrequestbody"); err != nil || v != true {
t.Error(v)
t.Fatal(err)
} }
if err = xmlconf.Set("name", "astaxie"); err != nil { if err = xmlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if xmlconf.String("name") != "astaxie" { if xmlconf.String("name") != "astaxie" {
t.Fatal("get name error") t.Fatal("get name error")
} }
if xmlconf.Strings("emptystrings") != nil {
t.Fatal("get emtpy strings error")
}
} }

View File

@ -26,7 +26,7 @@
// //
// cnf, err := config.NewConfig("yaml", "config.yaml") // cnf, err := config.NewConfig("yaml", "config.yaml")
// //
// more docs http://beego.me/docs/module/config.md //More docs http://beego.me/docs/module/config.md
package yaml package yaml
import ( import (
@ -110,6 +110,7 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
log.Println("Not a Map? >> ", string(buf), data) log.Println("Not a Map? >> ", string(buf), data)
cnf = nil cnf = nil
} }
cnf = config.ExpandValueEnvForMap(cnf)
return return
} }
@ -121,10 +122,11 @@ type ConfigContainer struct {
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *ConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key]; ok { v, err := c.getData(key)
return config.ParseBool(v) if err != nil {
return false, err
} }
return false, fmt.Errorf("not exist key: %q", key) return config.ParseBool(v)
} }
// DefaultBool return the bool value if has no error // DefaultBool return the bool value if has no error
@ -139,8 +141,12 @@ func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
func (c *ConfigContainer) Int(key string) (int, error) { func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok { if v, err := c.getData(key); err != nil {
return int(v), nil return 0, err
} else if vv, ok := v.(int); ok {
return vv, nil
} else if vv, ok := v.(int64); ok {
return int(vv), nil
} }
return 0, errors.New("not int value") return 0, errors.New("not int value")
} }
@ -157,8 +163,10 @@ func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
func (c *ConfigContainer) Int64(key string) (int64, error) { func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok { if v, err := c.getData(key); err != nil {
return v, nil return 0, err
} else if vv, ok := v.(int64); ok {
return vv, nil
} }
return 0, errors.New("not bool value") return 0, errors.New("not bool value")
} }
@ -175,8 +183,14 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
// Float returns the float value for a given key. // Float returns the float value for a given key.
func (c *ConfigContainer) Float(key string) (float64, error) { func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok { if v, err := c.getData(key); err != nil {
return v, nil return 0.0, err
} else if vv, ok := v.(float64); ok {
return vv, nil
} else if vv, ok := v.(int); ok {
return float64(vv), nil
} else if vv, ok := v.(int64); ok {
return float64(vv), nil
} }
return 0.0, errors.New("not float64 value") return 0.0, errors.New("not float64 value")
} }
@ -193,8 +207,10 @@ func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
// String returns the string value for a given key. // String returns the string value for a given key.
func (c *ConfigContainer) String(key string) string { func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, err := c.getData(key); err == nil {
return v if vv, ok := v.(string); ok {
return vv
}
} }
return "" return ""
} }
@ -230,8 +246,8 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
v, ok := c.data[section]
if ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist setction")
@ -259,10 +275,19 @@ func (c *ConfigContainer) Set(key, val string) error {
// DIY returns the raw value by a given key. // DIY returns the raw value by a given key.
func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
return c.getData(key)
}
func (c *ConfigContainer) getData(key string) (interface{}, error) {
if len(key) == 0 {
return nil, errors.New("key is emtpy")
}
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
} }
return nil, errors.New("not exist key") return nil, fmt.Errorf("not exist key %q", key)
} }
func init() { func init() {

View File

@ -15,13 +15,17 @@
package yaml package yaml
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
) )
var yamlcontext = ` func TestYaml(t *testing.T) {
var (
yamlcontext = `
"appname": beeapi "appname": beeapi
"httpport": 8080 "httpport": 8080
"mysqlport": 3600 "mysqlport": 3600
@ -29,9 +33,27 @@ var yamlcontext = `
"runmode": dev "runmode": dev
"autorender": false "autorender": false
"copyrequestbody": true "copyrequestbody": true
"PATH": GOPATH
"path1": ${GOPATH}
"path2": ${GOPATH||/home/go}
"empty": ""
` `
func TestYaml(t *testing.T) { keyValue = map[string]interface{}{
"appname": "beeapi",
"httpport": 8080,
"mysqlport": int64(3600),
"PI": 3.1415976,
"runmode": "dev",
"autorender": false,
"copyrequestbody": true,
"PATH": "GOPATH",
"path1": os.Getenv("GOPATH"),
"path2": os.Getenv("GOPATH"),
"error": "",
"emptystrings": []string{},
}
)
f, err := os.Create("testyaml.conf") f, err := os.Create("testyaml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -47,32 +69,42 @@ func TestYaml(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if yamlconf.String("appname") != "beeapi" { if yamlconf.String("appname") != "beeapi" {
t.Fatal("appname not equal to beeapi") t.Fatal("appname not equal to beeapi")
} }
if port, err := yamlconf.Int("httpport"); err != nil || port != 8080 {
t.Error(port) for k, v := range keyValue {
t.Fatal(err)
var (
value interface{}
err error
)
switch v.(type) {
case int:
value, err = yamlconf.Int(k)
case int64:
value, err = yamlconf.Int64(k)
case float64:
value, err = yamlconf.Float(k)
case bool:
value, err = yamlconf.Bool(k)
case []string:
value = yamlconf.Strings(k)
case string:
value = yamlconf.String(k)
default:
value, err = yamlconf.DIY(k)
} }
if port, err := yamlconf.Int64("mysqlport"); err != nil || port != 3600 { if err != nil {
t.Error(port) t.Errorf("get key %q value fatal,%v err %s", k, v, err)
t.Fatal(err) } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) {
t.Errorf("get key %q value, want %v got %v .", k, v, value)
} }
if pi, err := yamlconf.Float("PI"); err != nil || pi != 3.1415976 {
t.Error(pi)
t.Fatal(err)
}
if yamlconf.String("runmode") != "dev" {
t.Fatal("runmode not equal to dev")
}
if v, err := yamlconf.Bool("autorender"); err != nil || v != false {
t.Error(v)
t.Fatal(err)
}
if v, err := yamlconf.Bool("copyrequestbody"); err != nil || v != true {
t.Error(v)
t.Fatal(err)
} }
if err = yamlconf.Set("name", "astaxie"); err != nil { if err = yamlconf.Set("name", "astaxie"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -80,7 +112,4 @@ func TestYaml(t *testing.T) {
t.Fatal("get name error") t.Fatal("get name error")
} }
if yamlconf.Strings("emptystrings") != nil {
t.Fatal("get emtpy strings error")
}
} }

View File

@ -15,7 +15,11 @@
package beego package beego
import ( import (
"encoding/json"
"reflect"
"testing" "testing"
"github.com/astaxie/beego/config"
) )
func TestDefaults(t *testing.T) { func TestDefaults(t *testing.T) {
@ -27,3 +31,109 @@ func TestDefaults(t *testing.T) {
t.Errorf("FlashName was not set to default.") t.Errorf("FlashName was not set to default.")
} }
} }
func TestAssignConfig_01(t *testing.T) {
_BConfig := &Config{}
_BConfig.AppName = "beego_test"
jcf := &config.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`))
assignSingleConfig(_BConfig, ac)
if _BConfig.AppName != "beego_json" {
t.Log(_BConfig)
t.FailNow()
}
}
func TestAssignConfig_02(t *testing.T) {
_BConfig := &Config{}
bs, _ := json.Marshal(newBConfig())
jsonMap := map[string]interface{}{}
json.Unmarshal(bs, &jsonMap)
configMap := map[string]interface{}{}
for k, v := range jsonMap {
if reflect.TypeOf(v).Kind() == reflect.Map {
for k1, v1 := range v.(map[string]interface{}) {
if reflect.TypeOf(v1).Kind() == reflect.Map {
for k2, v2 := range v1.(map[string]interface{}) {
configMap[k2] = v2
}
} else {
configMap[k1] = v1
}
}
} else {
configMap[k] = v
}
}
configMap["MaxMemory"] = 1024
configMap["Graceful"] = true
configMap["XSRFExpire"] = 32
configMap["SessionProviderConfig"] = "file"
configMap["FileLineNum"] = true
jcf := &config.JSONConfig{}
bs, _ = json.Marshal(configMap)
ac, _ := jcf.ParseData([]byte(bs))
for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
if _BConfig.MaxMemory != 1024 {
t.Log(_BConfig.MaxMemory)
t.FailNow()
}
if !_BConfig.Listen.Graceful {
t.Log(_BConfig.Listen.Graceful)
t.FailNow()
}
if _BConfig.WebConfig.XSRFExpire != 32 {
t.Log(_BConfig.WebConfig.XSRFExpire)
t.FailNow()
}
if _BConfig.WebConfig.Session.SessionProviderConfig != "file" {
t.Log(_BConfig.WebConfig.Session.SessionProviderConfig)
t.FailNow()
}
if !_BConfig.Log.FileLineNum {
t.Log(_BConfig.Log.FileLineNum)
t.FailNow()
}
}
func TestAssignConfig_03(t *testing.T) {
jcf := &config.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`))
ac.Set("AppName", "test_app")
ac.Set("RunMode", "online")
ac.Set("StaticDir", "download:down download2:down2")
ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png")
assignConfig(ac)
t.Logf("%#v",BConfig)
if BConfig.AppName != "test_app" {
t.FailNow()
}
if BConfig.RunMode != "online" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download"] != "down" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download2"] != "down2" {
t.FailNow()
}
if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 {
t.FailNow()
}
}

View File

@ -27,6 +27,33 @@ import (
"sync" "sync"
) )
var (
//Default size==20B same as nginx
defaultGzipMinLength = 20
//Content will only be compressed if content length is either unknown or greater than gzipMinLength.
gzipMinLength = defaultGzipMinLength
//The compression level used for deflate compression. (0-9).
gzipCompressLevel int
//List of HTTP methods to compress. If not set, only GET requests are compressed.
includedMethods map[string]bool
getMethodOnly bool
)
func InitGzip(minLength, compressLevel int, methods []string) {
if minLength >= 0 {
gzipMinLength = minLength
}
gzipCompressLevel = compressLevel
if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression {
gzipCompressLevel = flate.BestSpeed
}
getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET")
includedMethods = make(map[string]bool, len(methods))
for _, v := range methods {
includedMethods[strings.ToUpper(v)] = true
}
}
type resetWriter interface { type resetWriter interface {
io.Writer io.Writer
Reset(w io.Writer) Reset(w io.Writer)
@ -43,18 +70,18 @@ func (n nopResetWriter) Reset(w io.Writer) {
type acceptEncoder struct { type acceptEncoder struct {
name string name string
levelEncode func(int) resetWriter levelEncode func(int) resetWriter
bestSpeedPool *sync.Pool customCompressLevelPool *sync.Pool
bestCompressionPool *sync.Pool bestCompressionPool *sync.Pool
} }
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr} return nopResetWriter{wr}
} }
var rwr resetWriter var rwr resetWriter
switch level { switch level {
case flate.BestSpeed: case flate.BestSpeed:
rwr = ac.bestSpeedPool.Get().(resetWriter) rwr = ac.customCompressLevelPool.Get().(resetWriter)
case flate.BestCompression: case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter) rwr = ac.bestCompressionPool.Get().(resetWriter)
default: default:
@ -65,13 +92,18 @@ func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
} }
func (ac acceptEncoder) put(wr resetWriter, level int) { func (ac acceptEncoder) put(wr resetWriter, level int) {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return return
} }
wr.Reset(nil) wr.Reset(nil)
//notice
//compressionLevel==BestCompression DOES NOT MATTER
//sync.Pool will not memory leak
switch level { switch level {
case flate.BestSpeed: case gzipCompressLevel:
ac.bestSpeedPool.Put(wr) ac.customCompressLevelPool.Put(wr)
case flate.BestCompression: case flate.BestCompression:
ac.bestCompressionPool.Put(wr) ac.bestCompressionPool.Put(wr)
} }
@ -79,28 +111,22 @@ func (ac acceptEncoder) put(wr resetWriter, level int) {
var ( var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil} noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
gzipCompressEncoder = acceptEncoder{"gzip", gzipCompressEncoder = acceptEncoder{
func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, name: "gzip",
&sync.Pool{ levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr }, customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }},
}, bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }},
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
},
} }
//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
//deflate //deflate
//The "zlib" format defined in RFC 1950 [31] in combination with //The "zlib" format defined in RFC 1950 [31] in combination with
//the "deflate" compression mechanism described in RFC 1951 [29]. //the "deflate" compression mechanism described in RFC 1951 [29].
deflateCompressEncoder = acceptEncoder{"deflate", deflateCompressEncoder = acceptEncoder{
func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, name: "deflate",
&sync.Pool{ levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr }, customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }},
}, bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }},
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
},
} }
) )
@ -120,7 +146,11 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string,
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) // WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) if encoding == "" || len(content) < gzipMinLength {
_, err := writer.Write(content)
return false, "", err
}
return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel)
} }
// writeLevel reads from reader,writes to writer by specific encoding and compress level // writeLevel reads from reader,writes to writer by specific encoding and compress level
@ -156,7 +186,10 @@ func ParseEncoding(r *http.Request) string {
if r == nil { if r == nil {
return "" return ""
} }
if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] {
return parseEncoding(r) return parseEncoding(r)
}
return ""
} }
type q struct { type q struct {

View File

@ -24,13 +24,11 @@ package context
import ( import (
"bufio" "bufio"
"bytes"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -67,6 +65,7 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.ResponseWriter.reset(rw) ctx.ResponseWriter.reset(rw)
ctx.Input.Reset(ctx) ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx) ctx.Output.Reset(ctx)
ctx._xsrfToken = ""
} }
// Redirect does redirection to localurl with http header status code. // Redirect does redirection to localurl with http header status code.
@ -79,6 +78,7 @@ func (ctx *Context) Redirect(status int, localurl string) {
// Abort stops this request. // Abort stops this request.
// if beego.ErrorMaps exists, panic body. // if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) { func (ctx *Context) Abort(status int, body string) {
ctx.Output.SetStatus(status)
panic(body) panic(body)
} }
@ -195,14 +195,6 @@ func (r *Response) Write(p []byte) (int, error) {
return r.ResponseWriter.Write(p) return r.ResponseWriter.Write(p)
} }
// Copy writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (r *Response) Copy(buf *bytes.Buffer) (int64, error) {
r.Started = true
return io.Copy(r.ResponseWriter, buf)
}
// WriteHeader sends an HTTP response header with status code, // WriteHeader sends an HTTP response header with status code,
// and sets `started` to true. // and sets `started` to true.
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {

47
context/context_test.go Normal file
View File

@ -0,0 +1,47 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestXsrfReset_01(t *testing.T) {
r := &http.Request{}
c := NewContext()
c.Request = r
c.ResponseWriter = &Response{}
c.ResponseWriter.reset(httptest.NewRecorder())
c.Output.Reset(c)
c.Input.Reset(c)
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
token := c._xsrfToken
c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r)
if c._xsrfToken != "" {
t.FailNow()
}
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
if token == c._xsrfToken {
t.FailNow()
}
}

View File

@ -100,7 +100,7 @@ func TestSubDomain(t *testing.T) {
/* TODO Fix this /* TODO Fix this
r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" { if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
} }

View File

@ -21,8 +21,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
"io"
"mime" "mime"
"net/http" "net/http"
"net/url"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -72,10 +75,11 @@ func (output *BeegoOutput) Body(content []byte) error {
if output.Status != 0 { if output.Status != 0 {
output.Context.ResponseWriter.WriteHeader(output.Status) output.Context.ResponseWriter.WriteHeader(output.Status)
output.Status = 0 output.Status = 0
} else {
output.Context.ResponseWriter.Started = true
} }
io.Copy(output.Context.ResponseWriter, buf)
_, err := output.Context.ResponseWriter.Copy(buf) return nil
return err
} }
// Cookie sets cookie value via given key. // Cookie sets cookie value via given key.
@ -235,13 +239,21 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
// Download forces response for download file. // Download forces response for download file.
// it prepares the download response header automatically. // it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) { func (output *BeegoOutput) Download(file string, filename ...string) {
// check get file error, file not found or other error.
if _, err := os.Stat(file); err != nil {
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
return
}
var fName string
if len(filename) > 0 && filename[0] != "" {
fName = filename[0]
} else {
fName = filepath.Base(file)
}
output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName))
output.Header("Content-Description", "File Transfer") output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream") output.Header("Content-Type", "application/octet-stream")
if len(filename) > 0 && filename[0] != "" {
output.Header("Content-Disposition", "attachment; filename="+filename[0])
} else {
output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file))
}
output.Header("Content-Transfer-Encoding", "binary") output.Header("Content-Transfer-Encoding", "binary")
output.Header("Expires", "0") output.Header("Expires", "0")
output.Header("Cache-Control", "must-revalidate") output.Header("Cache-Control", "must-revalidate")
@ -269,55 +281,55 @@ func (output *BeegoOutput) SetStatus(status int) {
// IsCachable returns boolean of this request is cached. // IsCachable returns boolean of this request is cached.
// HTTP 304 means cached. // HTTP 304 means cached.
func (output *BeegoOutput) IsCachable(status int) bool { func (output *BeegoOutput) IsCachable() bool {
return output.Status >= 200 && output.Status < 300 || output.Status == 304 return output.Status >= 200 && output.Status < 300 || output.Status == 304
} }
// IsEmpty returns boolean of this request is empty. // IsEmpty returns boolean of this request is empty.
// HTTP 201204 and 304 means empty. // HTTP 201204 and 304 means empty.
func (output *BeegoOutput) IsEmpty(status int) bool { func (output *BeegoOutput) IsEmpty() bool {
return output.Status == 201 || output.Status == 204 || output.Status == 304 return output.Status == 201 || output.Status == 204 || output.Status == 304
} }
// IsOk returns boolean of this request runs well. // IsOk returns boolean of this request runs well.
// HTTP 200 means ok. // HTTP 200 means ok.
func (output *BeegoOutput) IsOk(status int) bool { func (output *BeegoOutput) IsOk() bool {
return output.Status == 200 return output.Status == 200
} }
// IsSuccessful returns boolean of this request runs successfully. // IsSuccessful returns boolean of this request runs successfully.
// HTTP 2xx means ok. // HTTP 2xx means ok.
func (output *BeegoOutput) IsSuccessful(status int) bool { func (output *BeegoOutput) IsSuccessful() bool {
return output.Status >= 200 && output.Status < 300 return output.Status >= 200 && output.Status < 300
} }
// IsRedirect returns boolean of this request is redirection header. // IsRedirect returns boolean of this request is redirection header.
// HTTP 301,302,307 means redirection. // HTTP 301,302,307 means redirection.
func (output *BeegoOutput) IsRedirect(status int) bool { func (output *BeegoOutput) IsRedirect() bool {
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
} }
// IsForbidden returns boolean of this request is forbidden. // IsForbidden returns boolean of this request is forbidden.
// HTTP 403 means forbidden. // HTTP 403 means forbidden.
func (output *BeegoOutput) IsForbidden(status int) bool { func (output *BeegoOutput) IsForbidden() bool {
return output.Status == 403 return output.Status == 403
} }
// IsNotFound returns boolean of this request is not found. // IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden. // HTTP 404 means forbidden.
func (output *BeegoOutput) IsNotFound(status int) bool { func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404 return output.Status == 404
} }
// IsClientError returns boolean of this request client sends error data. // IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden. // HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool { func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500 return output.Status >= 400 && output.Status < 500
} }
// IsServerError returns boolean of this server handler errors. // IsServerError returns boolean of this server handler errors.
// HTTP 5xx means server internal error. // HTTP 5xx means server internal error.
func (output *BeegoOutput) IsServerError(status int) bool { func (output *BeegoOutput) IsServerError() bool {
return output.Status >= 500 && output.Status < 600 return output.Status >= 500 && output.Status < 600
} }

View File

@ -208,7 +208,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue continue
} }
buf.Reset() buf.Reset()
err = executeTemplate(&buf, sectionTpl, c.Data) err = ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,7 +217,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
} }
buf.Reset() buf.Reset()
executeTemplate(&buf, c.Layout, c.Data) ExecuteTemplate(&buf, c.Layout, c.Data)
} }
return buf.Bytes(), err return buf.Bytes(), err
} }
@ -242,7 +242,7 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
} }
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
} }
return buf, executeTemplate(&buf, c.TplName, c.Data) return buf, ExecuteTemplate(&buf, c.TplName, c.Data)
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
@ -261,12 +261,13 @@ func (c *Controller) Abort(code string) {
// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body.
func (c *Controller) CustomAbort(status int, body string) { func (c *Controller) CustomAbort(status int, body string) {
c.Ctx.Output.Status = status // first panic from ErrorMaps, it is user defined error functions.
// first panic from ErrorMaps, is is user defined error functions.
if _, ok := ErrorMaps[body]; ok { if _, ok := ErrorMaps[body]; ok {
c.Ctx.Output.Status = status
panic(body) panic(body)
} }
// last panic user string // last panic user string
c.Ctx.ResponseWriter.WriteHeader(status)
c.Ctx.ResponseWriter.Write([]byte(body)) c.Ctx.ResponseWriter.Write([]byte(body))
panic(ErrAbort) panic(ErrAbort)
} }

211
error.go
View File

@ -210,159 +210,139 @@ var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error. // show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) { func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 401,
"Title": http.StatusText(401), "<br>The page you have requested can't be authorized."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." + "<br>The credentials you supplied are incorrect"+
"<br>Perhaps you are here because:" + "<br>There are errors in the website address"+
"<br><br><ul>" + "</ul>",
"<br>The credentials you supplied are incorrect" + )
"<br>There are errors in the website address" +
"</ul>")
t.Execute(rw, data)
} }
// show 402 Payment Required // show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) { func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 402,
"Title": http.StatusText(402), "<br>The page you have requested Payment Required."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested Payment Required." + "<br>The credentials you supplied are incorrect"+
"<br>Perhaps you are here because:" + "<br>There are errors in the website address"+
"<br><br><ul>" + "</ul>",
"<br>The credentials you supplied are incorrect" + )
"<br>There are errors in the website address" +
"</ul>")
t.Execute(rw, data)
} }
// show 403 forbidden error. // show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) { func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 403,
"Title": http.StatusText(403), "<br>The page you have requested is forbidden."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested is forbidden." + "<br>Your address may be blocked"+
"<br>Perhaps you are here because:" + "<br>The site may be disabled"+
"<br><br><ul>" + "<br>You need to log in"+
"<br>Your address may be blocked" + "</ul>",
"<br>The site may be disabled" + )
"<br>You need to log in" +
"</ul>")
t.Execute(rw, data)
} }
// show 404 notfound error. // show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) { func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 404,
"Title": http.StatusText(404), "<br>The page you have requested has flown the coop."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." + "<br>The page has moved"+
"<br>Perhaps you are here because:" + "<br>The page no longer exists"+
"<br><br><ul>" + "<br>You were looking for your puppy and got lost"+
"<br>The page has moved" + "<br>You like 404 pages"+
"<br>The page no longer exists" + "</ul>",
"<br>You were looking for your puppy and got lost" + )
"<br>You like 404 pages" +
"</ul>")
t.Execute(rw, data)
} }
// show 405 Method Not Allowed // show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 405,
"Title": http.StatusText(405), "<br>The method you have requested Not Allowed."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The method you have requested Not Allowed." + "<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+
"<br>Perhaps you are here because:" + "<br>The response MUST include an Allow header containing a list of valid methods for the requested resource."+
"<br><br><ul>" + "</ul>",
"<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" + )
"<br>The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"</ul>")
t.Execute(rw, data)
} }
// show 500 internal server error. // show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) { func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 500,
"Title": http.StatusText(500), "<br>The page you have requested is down right now."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>Please try again later and report the error to the website administrator"+
data["Content"] = template.HTML("<br>The page you have requested is down right now." + "<br></ul>",
"<br><br><ul>" + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 501 Not Implemented. // show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) { func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 501,
"Title": http.StatusText(504), "<br>The page you have requested is Not Implemented."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>Please try again later and report the error to the website administrator"+
data["Content"] = template.HTML("<br>The page you have requested is Not Implemented." + "<br></ul>",
"<br><br><ul>" + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 502 Bad Gateway. // show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) { func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 502,
"Title": http.StatusText(502), "<br>The page you have requested is down right now."+
"BeegoVersion": VERSION, "<br><br><ul>"+
} "<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+
data["Content"] = template.HTML("<br>The page you have requested is down right now." + "<br>Please try again later and report the error to the website administrator"+
"<br><br><ul>" + "<br></ul>",
"<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." + )
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
t.Execute(rw, data)
} }
// show 503 service unavailable error. // show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) responseError(rw, r,
data := map[string]interface{}{ 503,
"Title": http.StatusText(503), "<br>The page you have requested is unavailable."+
"BeegoVersion": VERSION, "<br>Perhaps you are here because:"+
} "<br><br><ul>"+
data["Content"] = template.HTML("<br>The page you have requested is unavailable." + "<br><br>The page is overloaded"+
"<br>Perhaps you are here because:" + "<br>Please try again later."+
"<br><br><ul>" + "</ul>",
"<br><br>The page is overloaded" + )
"<br>Please try again later." +
"</ul>")
t.Execute(rw, data)
} }
// show 504 Gateway Timeout. // show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
504,
"<br>The page you have requested is unavailable"+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+
"<br>Please try again later."+
"</ul>",
)
}
func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := map[string]interface{}{ data := map[string]interface{}{
"Title": http.StatusText(504), "Title": http.StatusText(errCode),
"BeegoVersion": VERSION, "BeegoVersion": VERSION,
"Content": template.HTML(errContent),
} }
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
"<br>Please try again later." +
"</ul>")
t.Execute(rw, data) t.Execute(rw, data)
} }
@ -408,8 +388,11 @@ func exception(errCode string, ctx *context.Context) {
if err == nil { if err == nil {
return v return v
} }
if ctx.Output.Status == 0 {
return 503 return 503
} }
return ctx.Output.Status
}
for _, ec := range []string{errCode, "503", "500"} { for _, ec := range []string{errCode, "503", "500"} {
if h, ok := ErrorMaps[ec]; ok { if h, ok := ErrorMaps[ec]; ok {

88
error_test.go Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
)
type errorTestController struct {
Controller
}
const parseCodeError = "parse code error"
func (ec *errorTestController) Get() {
errorCode, err := ec.GetInt("code")
if err != nil {
ec.Abort(parseCodeError)
}
if errorCode != 0 {
ec.CustomAbort(errorCode, ec.GetString("code"))
}
ec.Abort("404")
}
func TestErrorCode_01(t *testing.T) {
registerDefaultErrorHandler()
for k := range ErrorMaps {
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
code, _ := strconv.Atoi(k)
if w.Code != code {
t.Fail()
}
if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) {
t.Fail()
}
}
}
func TestErrorCode_02(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=0", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 404 {
t.Fail()
}
}
func TestErrorCode_03(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=panic", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 200 {
t.Fail()
}
if string(w.Body.Bytes()) != parseCodeError {
t.Fail()
}
}

View File

@ -20,14 +20,8 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
func init() {
BeeLogger = logs.NewLogger(10000)
BeeLogger.SetLogger("console", "")
}
var FilterUser = func(ctx *context.Context) { var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
} }

View File

@ -6,6 +6,8 @@ import (
"net/http" "net/http"
"path/filepath" "path/filepath"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
@ -52,6 +54,9 @@ func registerSession() error {
"enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie, "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
"domain": BConfig.WebConfig.Session.SessionDomain, "domain": BConfig.WebConfig.Session.SessionDomain,
"cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime, "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
"enableSidInHttpHeader": BConfig.WebConfig.Session.EnableSidInHttpHeader,
"sessionNameInHttpHeader": BConfig.WebConfig.Session.SessionNameInHttpHeader,
"enableSidInUrlQuery": BConfig.WebConfig.Session.EnableSidInUrlQuery,
} }
confBytes, err := json.Marshal(conf) confBytes, err := json.Marshal(conf)
if err != nil { if err != nil {
@ -70,24 +75,27 @@ func registerSession() error {
func registerTemplate() error { func registerTemplate() error {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn(err) logs.Warn(err)
} }
return err return err
} }
return nil return nil
} }
func registerDocs() error {
if BConfig.WebConfig.EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
return nil
}
func registerAdmin() error { func registerAdmin() error {
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
go beeAdminApp.Run() go beeAdminApp.Run()
} }
return nil return nil
} }
func registerGzip() error {
if BConfig.EnableGzip {
context.InitGzip(
AppConfig.DefaultInt("gzipMinLength", -1),
AppConfig.DefaultInt("gzipCompressLevel", -1),
AppConfig.DefaultStrings("includedMethods", []string{"GET"}),
)
}
return nil
}

35
log.go
View File

@ -33,82 +33,77 @@ const (
) )
// BeeLogger references the used application logger. // BeeLogger references the used application logger.
var BeeLogger = logs.NewLogger(100) var BeeLogger = logs.GetBeeLogger()
// SetLevel sets the global log level used by the simple logger. // SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) { func SetLevel(l int) {
BeeLogger.SetLevel(l) logs.SetLevel(l)
} }
// SetLogFuncCall set the CallDepth, default is 3 // SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) { func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b) logs.SetLogFuncCall(b)
BeeLogger.SetLogFuncCallDepth(3)
} }
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error { func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config) return logs.SetLogger(adaptername, config)
if err != nil {
return err
}
return nil
} }
// Emergency logs a message at emergency level. // Emergency logs a message at emergency level.
func Emergency(v ...interface{}) { func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...) logs.Emergency(generateFmtStr(len(v)), v...)
} }
// Alert logs a message at alert level. // Alert logs a message at alert level.
func Alert(v ...interface{}) { func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...) logs.Alert(generateFmtStr(len(v)), v...)
} }
// Critical logs a message at critical level. // Critical logs a message at critical level.
func Critical(v ...interface{}) { func Critical(v ...interface{}) {
BeeLogger.Critical(generateFmtStr(len(v)), v...) logs.Critical(generateFmtStr(len(v)), v...)
} }
// Error logs a message at error level. // Error logs a message at error level.
func Error(v ...interface{}) { func Error(v ...interface{}) {
BeeLogger.Error(generateFmtStr(len(v)), v...) logs.Error(generateFmtStr(len(v)), v...)
} }
// Warning logs a message at warning level. // Warning logs a message at warning level.
func Warning(v ...interface{}) { func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...) logs.Warning(generateFmtStr(len(v)), v...)
} }
// Warn compatibility alias for Warning() // Warn compatibility alias for Warning()
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
BeeLogger.Warn(generateFmtStr(len(v)), v...) logs.Warn(generateFmtStr(len(v)), v...)
} }
// Notice logs a message at notice level. // Notice logs a message at notice level.
func Notice(v ...interface{}) { func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...) logs.Notice(generateFmtStr(len(v)), v...)
} }
// Informational logs a message at info level. // Informational logs a message at info level.
func Informational(v ...interface{}) { func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...) logs.Informational(generateFmtStr(len(v)), v...)
} }
// Info compatibility alias for Warning() // Info compatibility alias for Warning()
func Info(v ...interface{}) { func Info(v ...interface{}) {
BeeLogger.Info(generateFmtStr(len(v)), v...) logs.Info(generateFmtStr(len(v)), v...)
} }
// Debug logs a message at debug level. // Debug logs a message at debug level.
func Debug(v ...interface{}) { func Debug(v ...interface{}) {
BeeLogger.Debug(generateFmtStr(len(v)), v...) logs.Debug(generateFmtStr(len(v)), v...)
} }
// Trace logs a message at trace level. // Trace logs a message at trace level.
// compatibility alias for Warning() // compatibility alias for Warning()
func Trace(v ...interface{}) { func Trace(v ...interface{}) {
BeeLogger.Trace(generateFmtStr(len(v)), v...) logs.Trace(generateFmtStr(len(v)), v...)
} }
func generateFmtStr(n int) string { func generateFmtStr(n int) string {

View File

@ -113,5 +113,5 @@ func (c *connWriter) needToConnectOnMsg() bool {
} }
func init() { func init() {
Register("conn", NewConn) Register(AdapterConn, NewConn)
} }

View File

@ -97,5 +97,5 @@ func (c *consoleWriter) Flush() {
} }
func init() { func init() {
Register("console", NewConsole) Register(AdapterConsole, NewConsole)
} }

View File

@ -76,5 +76,5 @@ func (el *esLogger) Flush() {
} }
func init() { func init() {
logs.Register("es", NewES) logs.Register(logs.AdapterEs, NewES)
} }

View File

@ -30,7 +30,7 @@ import (
// fileLogWriter implements LoggerInterface. // fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency. // It writes messages by lines limit, file size limit, or time frequency.
type fileLogWriter struct { type fileLogWriter struct {
sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file // The opened file
Filename string `json:"filename"` Filename string `json:"filename"`
fileWriter *os.File fileWriter *os.File
@ -47,6 +47,7 @@ type fileLogWriter struct {
Daily bool `json:"daily"` Daily bool `json:"daily"`
MaxDays int64 `json:"maxdays"` MaxDays int64 `json:"maxdays"`
dailyOpenDate int dailyOpenDate int
dailyOpenTime time.Time
Rotate bool `json:"rotate"` Rotate bool `json:"rotate"`
@ -60,9 +61,6 @@ type fileLogWriter struct {
// newFileWriter create a FileLogWriter returning as LoggerInterface. // newFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger { func newFileWriter() Logger {
w := &fileLogWriter{ w := &fileLogWriter{
Filename: "",
MaxLines: 1000000,
MaxSize: 1 << 28, //256 MB
Daily: true, Daily: true,
MaxDays: 7, MaxDays: 7,
Rotate: true, Rotate: true,
@ -77,7 +75,7 @@ func newFileWriter() Logger {
// { // {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxLines":10000, // "maxLines":10000,
// "maxsize":1<<30, // "maxsize":1024,
// "daily":true, // "daily":true,
// "maxDays":15, // "maxDays":15,
// "rotate":true, // "rotate":true,
@ -128,7 +126,9 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
h, d := formatTimeHeader(when) h, d := formatTimeHeader(when)
msg = string(h) + msg + "\n" msg = string(h) + msg + "\n"
if w.Rotate { if w.Rotate {
w.RLock()
if w.needRotate(len(msg), d) { if w.needRotate(len(msg), d) {
w.RUnlock()
w.Lock() w.Lock()
if w.needRotate(len(msg), d) { if w.needRotate(len(msg), d) {
if err := w.doRotate(when); err != nil { if err := w.doRotate(when); err != nil {
@ -136,6 +136,8 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
} }
} }
w.Unlock() w.Unlock()
} else {
w.RUnlock()
} }
} }
@ -162,7 +164,8 @@ func (w *fileLogWriter) initFd() error {
return fmt.Errorf("get stat err: %s\n", err) return fmt.Errorf("get stat err: %s\n", err)
} }
w.maxSizeCurSize = int(fInfo.Size()) w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenDate = time.Now().Day() w.dailyOpenTime = time.Now()
w.dailyOpenDate = w.dailyOpenTime.Day()
w.maxLinesCurLines = 0 w.maxLinesCurLines = 0
if fInfo.Size() > 0 { if fInfo.Size() > 0 {
count, err := w.lines() count, err := w.lines()
@ -204,22 +207,29 @@ func (w *fileLogWriter) lines() (int, error) {
// DoRotate means it need to write file in new file. // DoRotate means it need to write file in new file.
// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) // new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size)
func (w *fileLogWriter) doRotate(logTime time.Time) error { func (w *fileLogWriter) doRotate(logTime time.Time) error {
_, err := os.Lstat(w.Filename)
if err != nil {
return err
}
// file exists // file exists
// Find the next available number // Find the next available number
num := 1 num := 1
fName := "" fName := ""
_, err := os.Lstat(w.Filename)
if err != nil {
//even if the file is not exist or other ,we should RESTART the logger
goto RESTART_LOGGER
}
if w.MaxLines > 0 || w.MaxSize > 0 { if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= 999; num++ { for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix) fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
} }
} else { } else {
fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, logTime.Format("2006-01-02"), w.suffix) fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
for ; err == nil && num <= 999; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName)
}
} }
// return error if the last file checked still existed // return error if the last file checked still existed
if err == nil { if err == nil {
@ -231,16 +241,18 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name // Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger // even if occurs error,we MUST guarantee to restart new logger
renameErr := os.Rename(w.Filename, fName) err = os.Rename(w.Filename, fName)
// re-start logger // re-start logger
RESTART_LOGGER:
startLoggerErr := w.startLogger() startLoggerErr := w.startLogger()
go w.deleteOldLog() go w.deleteOldLog()
if startLoggerErr != nil { if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
} }
if renameErr != nil { if err != nil {
return fmt.Errorf("Rotate: %s\n", renameErr) return fmt.Errorf("Rotate: %s\n", err)
} }
return nil return nil
@ -255,8 +267,12 @@ func (w *fileLogWriter) deleteOldLog() {
} }
}() }()
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) { if info == nil {
if strings.HasPrefix(filepath.Base(path), w.fileNameOnly) && return
}
if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) { strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path) os.Remove(path)
} }
@ -278,5 +294,5 @@ func (w *fileLogWriter) Flush() {
} }
func init() { func init() {
Register("file", newFileWriter) Register(AdapterFile, newFileWriter)
} }

View File

@ -89,7 +89,7 @@ func TestFile2(t *testing.T) {
os.Remove("test2.log") os.Remove("test2.log")
} }
func TestFileRotate(t *testing.T) { func TestFileRotate_01(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug") log.Debug("debug")
@ -110,6 +110,43 @@ func TestFileRotate(t *testing.T) {
os.Remove("test3.log") os.Remove("test3.log")
} }
func TestFileRotate_02(t *testing.T) {
fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
testFileRotate(t, fn1, fn2)
}
func TestFileRotate_03(t *testing.T) {
fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2)
os.Remove(fn)
}
func testFileRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Daily: true,
MaxDays: 7,
Rotate: true,
Level: LevelTrace,
Perm: 0660,
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
os.Remove(file)
}
}
func exists(path string) (bool, error) { func exists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
if err == nil { if err == nil {

View File

@ -35,10 +35,12 @@ package logs
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"path" "path"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
) )
@ -55,16 +57,28 @@ const (
LevelDebug LevelDebug
) )
// Legacy loglevel constants to ensure backwards compatibility. // levelLogLogger is defined to implement log.Logger
// // the real log level will be LevelEmergency
// Deprecated: will be removed in 1.5.0. const levelLoggerImpl = -1
// Name for adapter with beego official support
const (
AdapterConsole = "console"
AdapterFile = "file"
AdapterMultiFile = "multifile"
AdapterMail = "stmp"
AdapterConn = "conn"
AdapterEs = "es"
)
// Legacy log level constants to ensure backwards compatibility.
const ( const (
LevelInfo = LevelInformational LevelInfo = LevelInformational
LevelTrace = LevelDebug LevelTrace = LevelDebug
LevelWarn = LevelWarning LevelWarn = LevelWarning
) )
type loggerType func() Logger type newLoggerFunc func() Logger
// Logger defines the behavior of a log provider. // Logger defines the behavior of a log provider.
type Logger interface { type Logger interface {
@ -74,12 +88,13 @@ type Logger interface {
Flush() Flush()
} }
var adapters = make(map[string]loggerType) var adapters = make(map[string]newLoggerFunc)
var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "}
// Register makes a log provide available by the provided name. // Register makes a log provide available by the provided name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
// it panics. // it panics.
func Register(name string, log loggerType) { func Register(name string, log newLoggerFunc) {
if log == nil { if log == nil {
panic("logs: Register provide is nil") panic("logs: Register provide is nil")
} }
@ -94,15 +109,19 @@ func Register(name string, log loggerType) {
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
init bool
enableFuncCallDepth bool enableFuncCallDepth bool
loggerFuncCallDepth int loggerFuncCallDepth int
asynchronous bool asynchronous bool
msgChanLen int64
msgChan chan *logMsg msgChan chan *logMsg
signalChan chan string signalChan chan string
wg sync.WaitGroup wg sync.WaitGroup
outputs []*nameLogger outputs []*nameLogger
} }
const defaultAsyncMsgLen = 1e3
type nameLogger struct { type nameLogger struct {
Logger Logger
name string name string
@ -119,18 +138,31 @@ var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger. // NewLogger returns a new BeeLogger.
// channelLen means the number of messages in chan(used where asynchronous is true). // channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way. // if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channelLen int64) *BeeLogger { func NewLogger(channelLens ...int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.level = LevelDebug bl.level = LevelDebug
bl.loggerFuncCallDepth = 2 bl.loggerFuncCallDepth = 2
bl.msgChan = make(chan *logMsg, channelLen) bl.msgChanLen = append(channelLens, 0)[0]
if bl.msgChanLen <= 0 {
bl.msgChanLen = defaultAsyncMsgLen
}
bl.signalChan = make(chan string, 1) bl.signalChan = make(chan string, 1)
bl.setLogger(AdapterConsole)
return bl return bl
} }
// Async set the log to asynchronous and start the goroutine // Async set the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async() *BeeLogger { func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger {
bl.lock.Lock()
defer bl.lock.Unlock()
if bl.asynchronous {
return bl
}
bl.asynchronous = true bl.asynchronous = true
if len(msgLen) > 0 && msgLen[0] > 0 {
bl.msgChanLen = msgLen[0]
}
bl.msgChan = make(chan *logMsg, bl.msgChanLen)
logMsgPool = &sync.Pool{ logMsgPool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return &logMsg{} return &logMsg{}
@ -143,10 +175,8 @@ func (bl *BeeLogger) Async() *BeeLogger {
// SetLogger provides a given logger adapter into BeeLogger with config string. // SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}. // config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adapterName string, config string) error { func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
bl.lock.Lock() config := append(configs, "{}")[0]
defer bl.lock.Unlock()
for _, l := range bl.outputs { for _, l := range bl.outputs {
if l.name == adapterName { if l.name == adapterName {
return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
@ -168,6 +198,18 @@ func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
return nil return nil
} }
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if !bl.init {
bl.outputs = []*nameLogger{}
bl.init = true
}
return bl.setLogger(adapterName, configs...)
}
// DelLogger remove a logger adapter in BeeLogger. // DelLogger remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adapterName string) error { func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock() bl.lock.Lock()
@ -196,7 +238,37 @@ func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) {
} }
} }
func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { func (bl *BeeLogger) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
// writeMsg will always add a '\n' character
if p[len(p)-1] == '\n' {
p = p[0 : len(p)-1]
}
// set levelLoggerImpl to ensure all log message will be write out
err = bl.writeMsg(levelLoggerImpl, string(p))
if err == nil {
return len(p), err
}
return 0, err
}
func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error {
if !bl.init {
bl.lock.Lock()
bl.setLogger(AdapterConsole)
bl.lock.Unlock()
}
if logLevel == levelLoggerImpl {
// set to emergency to ensure all log will be print out correctly
logLevel = LevelEmergency
} else {
msg = levelPrefix[logLevel] + msg
}
if len(v) > 0 {
msg = fmt.Sprintf(msg, v...)
}
when := time.Now() when := time.Now()
if bl.enableFuncCallDepth { if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
@ -205,7 +277,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
line = 0 line = 0
} }
_, filename := path.Split(file) _, filename := path.Split(file)
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
} }
if bl.asynchronous { if bl.asynchronous {
lm := logMsgPool.Get().(*logMsg) lm := logMsgPool.Get().(*logMsg)
@ -273,8 +345,7 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level { if LevelEmergency > bl.level {
return return
} }
msg := fmt.Sprintf("[M] "+format, v...) bl.writeMsg(LevelEmergency, format, v...)
bl.writeMsg(LevelEmergency, msg)
} }
// Alert Log ALERT level message. // Alert Log ALERT level message.
@ -282,8 +353,7 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level { if LevelAlert > bl.level {
return return
} }
msg := fmt.Sprintf("[A] "+format, v...) bl.writeMsg(LevelAlert, format, v...)
bl.writeMsg(LevelAlert, msg)
} }
// Critical Log CRITICAL level message. // Critical Log CRITICAL level message.
@ -291,8 +361,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level { if LevelCritical > bl.level {
return return
} }
msg := fmt.Sprintf("[C] "+format, v...) bl.writeMsg(LevelCritical, format, v...)
bl.writeMsg(LevelCritical, msg)
} }
// Error Log ERROR level message. // Error Log ERROR level message.
@ -300,17 +369,12 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level { if LevelError > bl.level {
return return
} }
msg := fmt.Sprintf("[E] "+format, v...) bl.writeMsg(LevelError, format, v...)
bl.writeMsg(LevelError, msg)
} }
// Warning Log WARNING level message. // Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) { func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarning > bl.level { bl.Warn(format, v...)
return
}
msg := fmt.Sprintf("[W] "+format, v...)
bl.writeMsg(LevelWarning, msg)
} }
// Notice Log NOTICE level message. // Notice Log NOTICE level message.
@ -318,17 +382,12 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level { if LevelNotice > bl.level {
return return
} }
msg := fmt.Sprintf("[N] "+format, v...) bl.writeMsg(LevelNotice, format, v...)
bl.writeMsg(LevelNotice, msg)
} }
// Informational Log INFORMATIONAL level message. // Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) { func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInformational > bl.level { bl.Info(format, v...)
return
}
msg := fmt.Sprintf("[I] "+format, v...)
bl.writeMsg(LevelInformational, msg)
} }
// Debug Log DEBUG level message. // Debug Log DEBUG level message.
@ -336,38 +395,31 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level { if LevelDebug > bl.level {
return return
} }
msg := fmt.Sprintf("[D] "+format, v...) bl.writeMsg(LevelDebug, format, v...)
bl.writeMsg(LevelDebug, msg)
} }
// Warn Log WARN level message. // Warn Log WARN level message.
// compatibility alias for Warning() // compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
if LevelWarning > bl.level { if LevelWarn > bl.level {
return return
} }
msg := fmt.Sprintf("[W] "+format, v...) bl.writeMsg(LevelWarn, format, v...)
bl.writeMsg(LevelWarning, msg)
} }
// Info Log INFO level message. // Info Log INFO level message.
// compatibility alias for Informational() // compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
if LevelInformational > bl.level { if LevelInfo > bl.level {
return return
} }
msg := fmt.Sprintf("[I] "+format, v...) bl.writeMsg(LevelInfo, format, v...)
bl.writeMsg(LevelInformational, msg)
} }
// Trace Log TRACE level message. // Trace Log TRACE level message.
// compatibility alias for Debug() // compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
if LevelDebug > bl.level { bl.Debug(format, v...)
return
}
msg := fmt.Sprintf("[D] "+format, v...)
bl.writeMsg(LevelDebug, msg)
} }
// Flush flush all chan data. // Flush flush all chan data.
@ -386,6 +438,7 @@ func (bl *BeeLogger) Close() {
if bl.asynchronous { if bl.asynchronous {
bl.signalChan <- "close" bl.signalChan <- "close"
bl.wg.Wait() bl.wg.Wait()
close(bl.msgChan)
} else { } else {
bl.flush() bl.flush()
for _, l := range bl.outputs { for _, l := range bl.outputs {
@ -393,7 +446,6 @@ func (bl *BeeLogger) Close() {
} }
bl.outputs = nil bl.outputs = nil
} }
close(bl.msgChan)
close(bl.signalChan) close(bl.signalChan)
} }
@ -407,6 +459,7 @@ func (bl *BeeLogger) Reset() {
} }
func (bl *BeeLogger) flush() { func (bl *BeeLogger) flush() {
if bl.asynchronous {
for { for {
if len(bl.msgChan) > 0 { if len(bl.msgChan) > 0 {
bm := <-bl.msgChan bm := <-bl.msgChan
@ -416,7 +469,165 @@ func (bl *BeeLogger) flush() {
} }
break break
} }
}
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// beeLogger references the used application logger.
var beeLogger *BeeLogger = NewLogger()
// GetLogger returns the default BeeLogger
func GetBeeLogger() *BeeLogger {
return beeLogger
}
var beeLoggerMap = struct {
sync.RWMutex
logs map[string]*log.Logger
}{
logs: map[string]*log.Logger{},
}
// GetLogger returns the default BeeLogger
func GetLogger(prefixes ...string) *log.Logger {
prefix := append(prefixes, "")[0]
if prefix != "" {
prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix))
}
beeLoggerMap.RLock()
l, ok := beeLoggerMap.logs[prefix]
if ok {
beeLoggerMap.RUnlock()
return l
}
beeLoggerMap.RUnlock()
beeLoggerMap.Lock()
defer beeLoggerMap.Unlock()
l, ok = beeLoggerMap.logs[prefix]
if !ok {
l = log.New(beeLogger, prefix, 0)
beeLoggerMap.logs[prefix] = l
}
return l
}
// Reset will remove all the adapter
func Reset() {
beeLogger.Reset()
}
func Async(msgLen ...int64) *BeeLogger {
return beeLogger.Async(msgLen...)
}
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
beeLogger.SetLevel(l)
}
// EnableFuncCallDepth enable log funcCallDepth
func EnableFuncCallDepth(b bool) {
beeLogger.enableFuncCallDepth = b
}
// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
beeLogger.EnableFuncCallDepth(b)
beeLogger.SetLogFuncCallDepth(3)
}
// SetLogFuncCallDepth set log funcCallDepth
func SetLogFuncCallDepth(d int) {
beeLogger.loggerFuncCallDepth = d
}
// SetLogger sets a new logger.
func SetLogger(adapter string, config ...string) error {
err := beeLogger.SetLogger(adapter, config...)
if err != nil {
return err
}
return nil
}
// Emergency logs a message at emergency level.
func Emergency(f interface{}, v ...interface{}) {
beeLogger.Emergency(formatLog(f, v...))
}
// Alert logs a message at alert level.
func Alert(f interface{}, v ...interface{}) {
beeLogger.Alert(formatLog(f, v...))
}
// Critical logs a message at critical level.
func Critical(f interface{}, v ...interface{}) {
beeLogger.Critical(formatLog(f, v...))
}
// Error logs a message at error level.
func Error(f interface{}, v ...interface{}) {
beeLogger.Error(formatLog(f, v...))
}
// Warning logs a message at warning level.
func Warning(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Warn compatibility alias for Warning()
func Warn(f interface{}, v ...interface{}) {
beeLogger.Warn(formatLog(f, v...))
}
// Notice logs a message at notice level.
func Notice(f interface{}, v ...interface{}) {
beeLogger.Notice(formatLog(f, v...))
}
// Informational logs a message at info level.
func Informational(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Info compatibility alias for Warning()
func Info(f interface{}, v ...interface{}) {
beeLogger.Info(formatLog(f, v...))
}
// Debug logs a message at debug level.
func Debug(f interface{}, v ...interface{}) {
beeLogger.Debug(formatLog(f, v...))
}
// Trace logs a message at trace level.
// compatibility alias for Warning()
func Trace(f interface{}, v ...interface{}) {
beeLogger.Trace(formatLog(f, v...))
}
func formatLog(f interface{}, v ...interface{}) string {
var msg string
switch f.(type) {
case string:
msg = f.(string)
if len(v) == 0 {
return msg
}
if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") {
//format string
} else {
//do not contain format char
msg += strings.Repeat(" %v", len(v))
}
default:
msg = fmt.Sprint(f)
if len(v) == 0 {
return msg
}
msg += strings.Repeat(" %v", len(v))
}
return fmt.Sprintf(msg, v...)
}

View File

@ -36,43 +36,46 @@ func (lg *logWriter) println(when time.Time, msg string) {
lg.Unlock() lg.Unlock()
} }
const y1 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
const y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
const mo1 = `000000000111`
const mo2 = `123456789012`
const d1 = `0000000001111111111222222222233`
const d2 = `1234567890123456789012345678901`
const h1 = `000000000011111111112222`
const h2 = `012345678901234567890123`
const mi1 = `000000000011111111112222222222333333333344444444445555555555`
const mi2 = `012345678901234567890123456789012345678901234567890123456789`
const s1 = `000000000011111111112222222222333333333344444444445555555555`
const s2 = `012345678901234567890123456789012345678901234567890123456789`
func formatTimeHeader(when time.Time) ([]byte, int) { func formatTimeHeader(when time.Time) ([]byte, int) {
y, mo, d := when.Date() y, mo, d := when.Date()
h, mi, s := when.Clock() h, mi, s := when.Clock()
//len(2006/01/02 15:03:04)==19 //len("2006/01/02 15:04:05 ")==20
var buf [20]byte var buf [20]byte
t := 3
for y >= 10 { //change to '3' after 984 years, LOL
p := y / 10 buf[0] = '2'
buf[t] = byte('0' + y - p*10) //change to '1' after 84 years, LOL
y = p buf[1] = '0'
t-- buf[2] = y1[y-2000]
} buf[3] = y2[y-2000]
buf[0] = byte('0' + y)
buf[4] = '/' buf[4] = '/'
if mo > 9 { buf[5] = mo1[mo-1]
buf[5] = '1' buf[6] = mo2[mo-1]
buf[6] = byte('0' + mo - 9)
} else {
buf[5] = '0'
buf[6] = byte('0' + mo)
}
buf[7] = '/' buf[7] = '/'
t = d / 10 buf[8] = d1[d-1]
buf[8] = byte('0' + t) buf[9] = d2[d-1]
buf[9] = byte('0' + d - t*10)
buf[10] = ' ' buf[10] = ' '
t = h / 10 buf[11] = h1[h]
buf[11] = byte('0' + t) buf[12] = h2[h]
buf[12] = byte('0' + h - t*10)
buf[13] = ':' buf[13] = ':'
t = mi / 10 buf[14] = mi1[mi]
buf[14] = byte('0' + t) buf[15] = mi2[mi]
buf[15] = byte('0' + mi - t*10)
buf[16] = ':' buf[16] = ':'
t = s / 10 buf[17] = s1[s]
buf[17] = byte('0' + t) buf[18] = s2[s]
buf[18] = byte('0' + s - t*10)
buf[19] = ' ' buf[19] = ' '
return buf[0:], d return buf[0:], d

57
logs/logger_test.go Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"testing"
"time"
)
func TestFormatHeader_0(t *testing.T) {
tm := time.Now()
if tm.Year() >= 2100 {
t.FailNow()
}
dur := time.Second
for {
if tm.Year() >= 2100 {
break
}
h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
dur *= 2
}
}
func TestFormatHeader_1(t *testing.T) {
tm := time.Now()
year := tm.Year()
dur := time.Second
for {
if tm.Year() >= year+1 {
break
}
h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) {
t.Log(tm)
t.FailNow()
}
tm = tm.Add(dur)
}
}

View File

@ -112,5 +112,5 @@ func newFilesWriter() Logger {
} }
func init() { func init() {
Register("multifile", newFilesWriter) Register(AdapterMultiFile, newFilesWriter)
} }

View File

@ -156,5 +156,5 @@ func (s *SMTPWriter) Destroy() {
} }
func init() { func init() {
Register("smtp", newSMTPWriter) Register(AdapterMail, newSMTPWriter)
} }

View File

@ -33,7 +33,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/astaxie/beego" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
) )
@ -90,7 +90,7 @@ func (m *Migration) Reset() {
func (m *Migration) Exec(name, status string) error { func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm() o := orm.NewOrm()
for _, s := range m.sqls { for _, s := range m.sqls {
beego.Info("exec sql:", s) logs.Info("exec sql:", s)
r := o.Raw(s) r := o.Raw(s)
_, err := r.Exec() _, err := r.Exec()
if err != nil { if err != nil {
@ -144,20 +144,20 @@ func Upgrade(lasttime int64) error {
i := 0 i := 0
for _, v := range sm { for _, v := range sm {
if v.created > lasttime { if v.created > lasttime {
beego.Info("start upgrade", v.name) logs.Info("start upgrade", v.name)
v.m.Reset() v.m.Reset()
v.m.Up() v.m.Up()
err := v.m.Exec(v.name, "up") err := v.m.Exec(v.name, "up")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
beego.Info("end upgrade:", v.name) logs.Info("end upgrade:", v.name)
i++ i++
} }
} }
beego.Info("total success upgrade:", i, " migration") logs.Info("total success upgrade:", i, " migration")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
@ -165,20 +165,20 @@ func Upgrade(lasttime int64) error {
// Rollback rollback the migration by the name // Rollback rollback the migration by the name
func Rollback(name string) error { func Rollback(name string) error {
if v, ok := migrationMap[name]; ok { if v, ok := migrationMap[name]; ok {
beego.Info("start rollback") logs.Info("start rollback")
v.Reset() v.Reset()
v.Down() v.Down()
err := v.Exec(name, "down") err := v.Exec(name, "down")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
beego.Info("end rollback") logs.Info("end rollback")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
beego.Error("not exist the migrationMap name:" + name) logs.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name) return errors.New("not exist the migrationMap name:" + name)
} }
@ -191,23 +191,23 @@ func Reset() error {
for j := len(sm) - 1; j >= 0; j-- { for j := len(sm) - 1; j >= 0; j-- {
v := sm[j] v := sm[j]
if isRollBack(v.name) { if isRollBack(v.name) {
beego.Info("skip the", v.name) logs.Info("skip the", v.name)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
continue continue
} }
beego.Info("start reset:", v.name) logs.Info("start reset:", v.name)
v.m.Reset() v.m.Reset()
v.m.Down() v.m.Down()
err := v.m.Exec(v.name, "down") err := v.m.Exec(v.name, "down")
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
i++ i++
beego.Info("end reset:", v.name) logs.Info("end reset:", v.name)
} }
beego.Info("total success reset:", i, " migration") logs.Info("total success reset:", i, " migration")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} }
@ -216,7 +216,7 @@ func Reset() error {
func Refresh() error { func Refresh() error {
err := Reset() err := Reset()
if err != nil { if err != nil {
beego.Error("execute error:", err) logs.Error("execute error:", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return err return err
} }
@ -265,7 +265,7 @@ func isRollBack(name string) bool {
var maps []orm.Params var maps []orm.Params
num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps)
if err != nil { if err != nil {
beego.Info("get name has error", err) logs.Info("get name has error", err)
return false return false
} }
if num <= 0 { if num <= 0 {

View File

@ -44,7 +44,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
return ns return ns
} }
// Cond set condtion function // Cond set condition function
// if cond return true can run this namespace, else can't // if cond return true can run this namespace, else can't
// usage: // usage:
// ns.Cond(func (ctx *context.Context) bool{ // ns.Cond(func (ctx *context.Context) bool{
@ -60,7 +60,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
exception("405", ctx) exception("405", ctx)
} }
} }
if v, ok := n.handlers.filters[BeforeRouter]; ok { if v := n.handlers.filters[BeforeRouter]; len(v) > 0 {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.tree = NewTree() mr.tree = NewTree()
mr.pattern = "*" mr.pattern = "*"

View File

@ -52,9 +52,15 @@ checkColumn:
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeCharField: case TypeCharField:
if al.Driver == DRPostgres && fi.toText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize) col = fmt.Sprintf(T["string"], fieldSize)
}
case TypeTextField: case TypeTextField:
col = T["string-text"] col = T["string-text"]
case TypeTimeField:
col = T["time.Time-clock"]
case TypeDateField: case TypeDateField:
col = T["time.Time-date"] col = T["time.Time-date"]
case TypeDateTimeField: case TypeDateTimeField:
@ -88,6 +94,18 @@ checkColumn:
} else { } else {
col = fmt.Sprintf(s, fi.digits, fi.decimals) col = fmt.Sprintf(s, fi.digits, fi.decimals)
} }
case TypeJSONField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
goto checkColumn
}
col = T["json"]
case TypeJsonbField:
if al.Driver != DRPostgres {
fieldType = TypeCharField
goto checkColumn
}
col = T["jsonb"]
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size fieldSize = fi.relModelInfo.fields.pk.size
@ -264,7 +282,7 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField, TypeTextField: case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v return v
case TypeBitField, TypeSmallIntegerField, TypeIntegerField, case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
@ -276,6 +294,8 @@ func getColumnDefault(fi *fieldInfo) string {
case TypeBooleanField: case TypeBooleanField:
t = " DEFAULT %s " t = " DEFAULT %s "
d = "FALSE" d = "FALSE"
case TypeJSONField, TypeJsonbField:
d = "{}"
} }
if fi.colDefault { if fi.colDefault {

116
orm/db.go
View File

@ -24,6 +24,7 @@ import (
) )
const ( const (
formatTime = "15:04:05"
formatDate = "2006-01-02" formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05" formatDateTime = "2006-01-02 15:04:05"
) )
@ -71,12 +72,12 @@ type dbBase struct {
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
// get struct columns values as interface slice. // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) {
var columns []string if names == nil {
ns := make([]string, 0, len(cols))
if names != nil { names = &ns
columns = *names
} }
values = make([]interface{}, 0, len(cols))
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
@ -90,18 +91,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if names != nil { // ignore empty value auto field
columns = append(columns, column) if insert && fi.auto {
if fi.fieldType&IsPositiveIntegerField > 0 {
if vu, ok := value.(uint64); !ok || vu == 0 {
continue
}
} else {
if vu, ok := value.(int64); !ok || vu == 0 {
continue
}
}
autoFields = append(autoFields, fi.column)
} }
values = append(values, value) *names, values = append(*names, column), append(values, value)
}
if names != nil {
*names = columns
} }
return return
@ -134,7 +141,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Bool() value = field.Bool()
} }
case TypeCharField, TypeTextField: case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
value = nil value = nil
if ns.Valid { if ns.Valid {
@ -169,7 +176,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
value = field.Float() value = field.Float()
} }
} }
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
value = field.Interface() value = field.Interface()
if t, ok := value.(time.Time); ok { if t, ok := value.(time.Time); ok {
d.ins.TimeToDB(&t, tz) d.ins.TimeToDB(&t, tz)
@ -181,7 +188,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
default: default:
switch { switch {
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPositiveIntegerField > 0:
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Ptr {
if field.IsNil() { if field.IsNil() {
value = nil value = nil
@ -223,7 +230,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
} }
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if fi.autoNow || fi.autoNowAdd && insert { if fi.autoNow || fi.autoNowAdd && insert {
if insert { if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() { if t, ok := value.(time.Time); ok && !t.IsZero() {
@ -240,6 +247,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
} }
} }
case TypeJSONField, TypeJsonbField:
if s, ok := value.(string); (ok && len(s) == 0) || value == nil {
if fi.colDefault && fi.initial.Exist() {
value = fi.initial.String()
} else {
value = nil
}
}
} }
} }
return value, nil return value, nil
@ -273,7 +288,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
// insert struct with prepared statement and given struct reflect value. // insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -300,7 +315,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols = make([]string, 0, len(cols)) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -349,13 +364,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// execute insert sql dbQuerier with given struct reflect.Value. // execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names := make([]string, 0, len(mi.fields.dbcols)-1) names := make([]string, 0, len(mi.fields.dbcols))
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, false, names, values) id, err := d.InsertValue(q, mi, false, names, values)
if err != nil {
return 0, err
}
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return id, err
} }
// multi-insert sql with given slice struct reflect.Value. // multi-insert sql with given slice struct reflect.Value.
@ -369,7 +392,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// typ := reflect.Indirect(mi.addrField).Type() // typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len() length, autoFields := sind.Len(), make([]string, 0, 1)
for i := 1; i <= length; i++ { for i := 1; i <= length; i++ {
@ -381,16 +404,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
// } // }
if i == 1 { if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) var (
vus []interface{}
err error
)
vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
values = make([]interface{}, bulk*len(vus)) values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus) nums += copy(values, vus)
} else { } else {
vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz)
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return cnt, err return cnt, err
} }
@ -412,7 +437,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul
} }
} }
return cnt, nil var err error
if len(autoFields) > 0 {
err = d.ins.setval(q, mi, autoFields)
}
return cnt, err
} }
// execute insert sql with given struct and given values. // execute insert sql with given struct and given values.
@ -472,7 +502,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
setNames = make([]string, 0, len(cols)) setNames = make([]string, 0, len(cols))
} }
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -516,7 +546,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0)
} else { } else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
@ -1071,13 +1101,13 @@ setValue:
} }
value = b value = b
} }
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil { if str == nil {
value = ToStr(val) value = ToStr(val)
} else { } else {
value = str.String() value = str.String()
} }
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil { if str == nil {
switch t := val.(type) { switch t := val.(type) {
case time.Time: case time.Time:
@ -1097,15 +1127,20 @@ setValue:
if len(s) >= 19 { if len(s) >= 19 {
s = s[:19] s = s[:19]
t, err = time.ParseInLocation(formatDateTime, s, tz) t, err = time.ParseInLocation(formatDateTime, s, tz)
} else { } else if len(s) >= 10 {
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} else if len(s) >= 8 {
if len(s) > 8 {
s = s[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
} }
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
tErr = err tErr = err
goto end goto end
} }
@ -1140,7 +1175,7 @@ setValue:
tErr = err tErr = err
goto end goto end
} }
if fieldType&IsPostiveIntegerField > 0 { if fieldType&IsPositiveIntegerField > 0 {
v, _ := str.Uint64() v, _ := str.Uint64()
value = v value = v
} else { } else {
@ -1212,7 +1247,7 @@ setValue:
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
} }
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil { if value == nil {
@ -1234,7 +1269,7 @@ setValue:
field.SetString(value.(string)) field.SetString(value.(string))
} }
} }
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative { if isNative {
if value == nil { if value == nil {
value = time.Time{} value = time.Time{}
@ -1292,7 +1327,7 @@ setValue:
field.Set(reflect.ValueOf(&v)) field.Set(reflect.ValueOf(&v))
} }
case fieldType&IsIntegerField > 0: case fieldType&IsIntegerField > 0:
if fieldType&IsPostiveIntegerField > 0 { if fieldType&IsPositiveIntegerField > 0 {
if isNative { if isNative {
if value == nil { if value == nil {
value = uint64(0) value = uint64(0)
@ -1562,6 +1597,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// sync auto key
func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
return nil
}
// convert time from db. // convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)

View File

@ -56,6 +56,8 @@ var postgresTypes = map[string]string{
"uint64": `bigint CHECK("%COL%" >= 0)`, "uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision", "float64": "double precision",
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
"json": "json",
"jsonb": "jsonb",
} }
// postgresql dbBaser. // postgresql dbBaser.
@ -123,14 +125,35 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
} }
// make returning sql support for postgresql. // make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
if mi.fields.pk.auto { fi := mi.fields.pk
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
return false
}
if query != nil { if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
} }
has = true return true
}
// sync auto key
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
} }
return
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.Exec(query); err != nil {
return err
}
}
return nil
} }
// show table sql for postgresql. // show table sql for postgresql.

View File

@ -33,13 +33,13 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
fi := mi.fields.pk fi := mi.fields.pk
v := ind.FieldByIndex(fi.fieldIndex) v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPostiveIntegerField > 0 { if fi.fieldType&IsPositiveIntegerField > 0 {
vu := v.Uint() vu := v.Uint()
exist = vu > 0 exist = vu > 0
value = vu value = vu
} else if fi.fieldType&IsIntegerField > 0 { } else if fi.fieldType&IsIntegerField > 0 {
vu := v.Int() vu := v.Int()
exist = vu > 0 exist = true
value = vu value = vu
} else { } else {
vu := v.String() vu := v.String()
@ -74,24 +74,32 @@ outFor:
case reflect.String: case reflect.String:
v := val.String() v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time var t time.Time
var err error var err error
if len(v) >= 19 { if len(v) >= 19 {
s := v[:19] s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else { } else if len(v) >= 10 {
s := v s := v
if len(v) > 10 { if len(v) > 10 {
s = v[:10] s = v[:10]
} }
t, err = time.ParseInLocation(formatDate, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} else {
s := v
if len(s) > 8 {
s = v[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
} }
if err == nil { if err == nil {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate) v = t.In(tz).Format(formatDate)
} else { } else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime) v = t.In(tz).Format(formatDateTime)
} else {
v = t.In(tz).Format(formatTime)
} }
} }
} }
@ -137,8 +145,10 @@ outFor:
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate) arg = v.In(tz).Format(formatDate)
} else { } else if fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime) arg = v.In(tz).Format(formatDateTime)
} else {
arg = v.In(tz).Format(formatTime)
} }
} else { } else {
typ := val.Type() typ := val.Type()

View File

@ -25,6 +25,7 @@ const (
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
TypeCharField TypeCharField
TypeTextField TypeTextField
TypeTimeField
TypeDateField TypeDateField
TypeDateTimeField TypeDateTimeField
TypeBitField TypeBitField
@ -37,6 +38,8 @@ const (
TypePositiveBigIntegerField TypePositiveBigIntegerField
TypeFloatField TypeFloatField
TypeDecimalField TypeDecimalField
TypeJSONField
TypeJsonbField
RelForeignKey RelForeignKey
RelOneToOne RelOneToOne
RelManyToMany RelManyToMany
@ -46,9 +49,9 @@ const (
// Define some logic enum // Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10
IsRelField = ^-RelReverseMany >> 14 << 15 IsRelField = ^-RelReverseMany >> 17 << 18
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
@ -145,6 +148,65 @@ func (e *CharField) RawValue() interface{} {
// verify CharField implement Fielder // verify CharField implement Fielder
var _ Fielder = new(CharField) var _ Fielder = new(CharField)
// TimeField A time, represented in go by a time.Time instance.
// only time values like 10:00:00
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type TimeField time.Time
// Value return the time.Time
func (e TimeField) Value() time.Time {
return time.Time(e)
}
// Set set the TimeField's value
func (e *TimeField) Set(d time.Time) {
*e = TimeField(d)
}
// String convert time to string
func (e *TimeField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *TimeField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *TimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, formatTime)
if err != nil {
e.Set(v)
}
return err
default:
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return time value
func (e *TimeField) RawValue() interface{} {
return e.Value()
}
var _ Fielder = new(TimeField)
// DateField A date, represented in go by a time.Time instance. // DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02 // only date values like 2006-01-02
// Has a few extra, optional attr tag: // Has a few extra, optional attr tag:
@ -627,3 +689,87 @@ func (e *TextField) RawValue() interface{} {
// verify TextField implement Fielder // verify TextField implement Fielder
var _ Fielder = new(TextField) var _ Fielder = new(TextField)
// JSONField postgres json field.
type JSONField string
// Value return JSONField value
func (j JSONField) Value() string {
return string(j)
}
// Set the JSONField value
func (j *JSONField) Set(d string) {
*j = JSONField(d)
}
// String convert JSONField to string
func (j *JSONField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JSONField) FieldType() int {
return TypeJSONField
}
// SetRaw convert interface string to string
func (j *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JSONField value
func (j *JSONField) RawValue() interface{} {
return j.Value()
}
// verify JSONField implement Fielder
var _ Fielder = new(JSONField)
// JsonbField postgres json field.
type JsonbField string
// Value return JsonbField value
func (j JsonbField) Value() string {
return string(j)
}
// Set the JsonbField value
func (j *JsonbField) Set(d string) {
*j = JsonbField(d)
}
// String convert JsonbField to string
func (j *JsonbField) String() string {
return j.Value()
}
// FieldType return enum type
func (j *JsonbField) FieldType() int {
return TypeJsonbField
}
// SetRaw convert interface string to string
func (j *JsonbField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
j.Set(d)
default:
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return JsonbField value
func (j *JsonbField) RawValue() interface{} {
return j.Value()
}
// verify JsonbField implement Fielder
var _ Fielder = new(JsonbField)

View File

@ -119,6 +119,7 @@ type fieldInfo struct {
colDefault bool colDefault bool
initial StrTo initial StrTo
size int size int
toText bool
autoNow bool autoNow bool
autoNowAdd bool autoNowAdd bool
rel bool rel bool
@ -239,8 +240,15 @@ checkType:
if err != nil { if err != nil {
goto end goto end
} }
if fieldType == TypeCharField && tags["type"] == "text" { if fieldType == TypeCharField {
switch tags["type"] {
case "text":
fieldType = TypeTextField fieldType = TypeTextField
case "json":
fieldType = TypeJSONField
case "jsonb":
fieldType = TypeJsonbField
}
} }
if fieldType == TypeFloatField && (digits != "" || decimals != "") { if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField fieldType = TypeDecimalField
@ -248,6 +256,9 @@ checkType:
if fieldType == TypeDateTimeField && tags["type"] == "date" { if fieldType == TypeDateTimeField && tags["type"] == "date" {
fieldType = TypeDateField fieldType = TypeDateField
} }
if fieldType == TypeTimeField && tags["type"] == "time" {
fieldType = TypeTimeField
}
} }
switch fieldType { switch fieldType {
@ -339,7 +350,7 @@ checkType:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
case TypeCharField: case TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" { if size != "" {
v, e := StrTo(size).Int32() v, e := StrTo(size).Int32()
if e != nil { if e != nil {
@ -349,11 +360,12 @@ checkType:
} }
} else { } else {
fi.size = 255 fi.size = 255
fi.toText = true
} }
case TypeTextField: case TypeTextField:
fi.index = false fi.index = false
fi.unique = false fi.unique = false
case TypeDateField, TypeDateTimeField: case TypeTimeField, TypeDateField, TypeDateTimeField:
if attrs["auto_now"] { if attrs["auto_now"] {
fi.autoNow = true fi.autoNow = true
} else if attrs["auto_now_add"] { } else if attrs["auto_now_add"] {
@ -406,7 +418,7 @@ checkType:
fi.index = false fi.index = false
} }
if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default // can not set default
initial.Clear() initial.Clear()
} }

View File

@ -78,40 +78,43 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField) var _ Fielder = new(SliceStringField)
// A json field. // A json field.
type JSONField struct { type JSONFieldTest struct {
Name string Name string
Data string Data string
} }
func (e *JSONField) String() string { func (e *JSONFieldTest) String() string {
data, _ := json.Marshal(e) data, _ := json.Marshal(e)
return string(data) return string(data)
} }
func (e *JSONField) FieldType() int { func (e *JSONFieldTest) FieldType() int {
return TypeTextField return TypeTextField
} }
func (e *JSONField) SetRaw(value interface{}) error { func (e *JSONFieldTest) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
return json.Unmarshal([]byte(d), e) return json.Unmarshal([]byte(d), e)
default: default:
return fmt.Errorf("<JsonField.SetRaw> unknown value `%v`", value) return fmt.Errorf("<JSONField.SetRaw> unknown value `%v`", value)
} }
} }
func (e *JSONField) RawValue() interface{} { func (e *JSONFieldTest) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(JSONField) var _ Fielder = new(JSONFieldTest)
type Data struct { type Data struct {
ID int `orm:"column(id)"` ID int `orm:"column(id)"`
Boolean bool Boolean bool
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
JSON string `orm:"type(json);default({\"name\":\"json\"})"`
Jsonb string `orm:"type(jsonb)"`
Time time.Time `orm:"type(time)"`
Date time.Time `orm:"type(date)"` Date time.Time `orm:"type(date)"`
DateTime time.Time `orm:"column(datetime)"` DateTime time.Time `orm:"column(datetime)"`
Byte byte Byte byte
@ -136,6 +139,9 @@ type DataNull struct {
Boolean bool `orm:"null"` Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"` Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"` Text string `orm:"null;type(text)"`
JSON string `orm:"type(json);null"`
Jsonb string `orm:"type(jsonb);null"`
Time time.Time `orm:"null;type(time)"`
Date time.Time `orm:"null;type(date)"` Date time.Time `orm:"null;type(date)"`
DateTime time.Time `orm:"null;column(datetime)"` DateTime time.Time `orm:"null;column(datetime)"`
Byte byte `orm:"null"` Byte byte `orm:"null"`
@ -237,7 +243,7 @@ type User struct {
ShouldSkip string `orm:"-"` ShouldSkip string `orm:"-"`
Nums int Nums int
Langs SliceStringField `orm:"size(100)"` Langs SliceStringField `orm:"size(100)"`
Extra JSONField `orm:"type(text)"` Extra JSONFieldTest `orm:"type(text)"`
unexport bool `orm:"-"` unexport bool `orm:"-"`
unexportBool bool unexportBool bool
} }
@ -375,6 +381,28 @@ func NewInLine() *InLine {
return new(InLine) return new(InLine)
} }
type InLineOneToOne struct {
// Common Fields
ModelBase
Note string
InLine *InLine `orm:"rel(fk);column(inline)"`
}
func NewInLineOneToOne() *InLineOneToOne {
return new(InLineOneToOne)
}
type IntegerPk struct {
ID int64 `orm:"pk"`
Value string
}
type UintPk struct {
ID uint32 `orm:"pk"`
Name string
}
var DBARGS = struct { var DBARGS = struct {
Driver string Driver string
Source string Source string

View File

@ -140,7 +140,14 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
return (err == nil), id, err return (err == nil), id, err
} }
return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint())
} else {
id = vid.Int()
}
return false, id, err
} }
// insert model data to database // insert model data to database
@ -159,7 +166,7 @@ func (o *orm) Insert(md interface{}) (int64, error) {
// set auto pk field // set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
@ -184,7 +191,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
if bulk <= 1 { if bulk <= 1 {
for i := 0; i < sind.Len(); i++ { for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i) ind := reflect.Indirect(sind.Index(i))
mi, _ := o.getMiInd(ind.Interface(), false) mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {

View File

@ -31,7 +31,7 @@ type Log struct {
// NewLog set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", log.LstdFlags)
return d return d
} }

View File

@ -50,7 +50,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
} }
if id > 0 { if id > 0 {
if o.mi.fields.pk.auto { if o.mi.fields.pk.auto {
if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else { } else {
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)

View File

@ -192,16 +192,18 @@ func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
// query one row data and map to containers. // query one row data and map to containers.
// cols means the columns when querying. // cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
o.limit = 1
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
} }
if num > 1 {
return ErrMultiRows
}
if num == 0 { if num == 0 {
return ErrNoRows return ErrNoRows
} }
if num > 1 {
return ErrMultiRows
}
return nil return nil
} }

View File

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -33,6 +34,7 @@ var _ = os.PathSeparator
var ( var (
testDate = formatDate + " -0700" testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700" testDateTime = formatDateTime + " -0700"
testTime = formatTime + " -0700"
) )
type argAny []interface{} type argAny []interface{}
@ -188,6 +190,9 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Permission)) RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions)) RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine)) RegisterModel(new(InLine))
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
@ -208,6 +213,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Permission)) RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions)) RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine)) RegisterModel(new(InLine))
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
BootStrap() BootStrap()
@ -233,6 +241,9 @@ var DataValues = map[string]interface{}{
"Boolean": true, "Boolean": true,
"Char": "char", "Char": "char",
"Text": "text", "Text": "text",
"JSON": `{"name":"json"}`,
"Jsonb": `{"name": "jsonb"}`,
"Time": time.Now(),
"Date": time.Now(), "Date": time.Now(),
"DateTime": time.Now(), "DateTime": time.Now(),
"Byte": byte(1<<8 - 1), "Byte": byte(1<<8 - 1),
@ -257,10 +268,12 @@ func TestDataTypes(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range DataValues { for name, value := range DataValues {
if name == "JSON" {
continue
}
e := ind.FieldByName(name) e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value)) e.Set(reflect.ValueOf(value))
} }
id, err := dORM.Insert(&d) id, err := dORM.Insert(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
@ -281,6 +294,9 @@ func TestDataTypes(t *testing.T) {
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -299,10 +315,18 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}`
d = DataNull{ID: 1, JSON: data}
num, err := dORM.Update(&d)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
d = DataNull{ID: 1} d = DataNull{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(d.JSON, data))
throwFail(t, AssertIs(d.NullBool.Valid, false)) throwFail(t, AssertIs(d.NullBool.Valid, false))
throwFail(t, AssertIs(d.NullString.Valid, false)) throwFail(t, AssertIs(d.NullString.Valid, false))
throwFail(t, AssertIs(d.NullInt64.Valid, false)) throwFail(t, AssertIs(d.NullInt64.Valid, false))
@ -969,12 +993,19 @@ func TestOne(t *testing.T) {
var user User var user User
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
err := qs.One(&user) err := qs.One(&user)
throwFail(t, AssertIs(err, ErrMultiRows)) throwFail(t, err)
user = User{} user = User{}
err = qs.OrderBy("Id").Limit(1).One(&user) err = qs.OrderBy("Id").Limit(1).One(&user)
throwFailNow(t, err) throwFailNow(t, err)
throwFail(t, AssertIs(user.UserName, "slene")) throwFail(t, AssertIs(user.UserName, "slene"))
throwFail(t, AssertNot(err, ErrMultiRows))
user = User{}
err = qs.OrderBy("-Id").Limit(100).One(&user)
throwFailNow(t, err)
throwFail(t, AssertIs(user.UserName, "nobody"))
throwFail(t, AssertNot(err, ErrMultiRows))
err = qs.Filter("user_name", "nothing").One(&user) err = qs.Filter("user_name", "nothing").One(&user)
throwFail(t, AssertIs(err, ErrNoRows)) throwFail(t, AssertIs(err, ErrNoRows))
@ -1514,6 +1545,7 @@ func TestRawQueryRow(t *testing.T) {
Boolean bool Boolean bool
Char string Char string
Text string Text string
Time time.Time
Date time.Time Date time.Time
DateTime time.Time DateTime time.Time
Byte byte Byte byte
@ -1542,14 +1574,14 @@ func TestRawQueryRow(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
cols := []string{ cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
} }
sep := fmt.Sprintf("%s, %s", Q, Q) sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
var id int var id int
values := []interface{}{ values := []interface{}{
&id, &Boolean, &Char, &Text, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32,
&Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal,
} }
err := dORM.Raw(query, 1).QueryRow(values...) err := dORM.Raw(query, 1).QueryRow(values...)
@ -1560,6 +1592,10 @@ func TestRawQueryRow(t *testing.T) {
switch col { switch col {
case "id": case "id":
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
case "time":
v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testTime))
case "date": case "date":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := dataValues[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
@ -1607,6 +1643,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@ -1631,6 +1670,9 @@ func TestQueryRows(t *testing.T) {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Time":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testTime)
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
@ -1952,3 +1994,126 @@ func TestInLine(t *testing.T) {
throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime))
} }
func TestInLineOneToOne(t *testing.T) {
name := "121"
email := "121@go.com"
inline := NewInLine()
inline.Name = name
inline.Email = email
id, err := dORM.Insert(inline)
throwFail(t, err)
throwFail(t, AssertIs(id, 2))
note := "one2one"
il121 := NewInLineOneToOne()
il121.Note = note
il121.InLine = inline
_, err = dORM.Insert(il121)
throwFail(t, err)
throwFail(t, AssertIs(il121.ID, 1))
il := NewInLineOneToOne()
err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il)
throwFail(t, err)
throwFail(t, AssertIs(il.Note, note))
throwFail(t, AssertIs(il.InLine.ID, id))
throwFail(t, AssertIs(il.InLine.Name, name))
throwFail(t, AssertIs(il.InLine.Email, email))
rinline := NewInLine()
err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline)
throwFail(t, err)
throwFail(t, AssertIs(rinline.ID, id))
throwFail(t, AssertIs(rinline.Name, name))
throwFail(t, AssertIs(rinline.Email, email))
}
func TestIntegerPk(t *testing.T) {
its := []IntegerPk{
{ID: math.MinInt64, Value: "-"},
{ID: 0, Value: "0"},
{ID: math.MaxInt64, Value: "+"},
}
num, err := dORM.InsertMulti(len(its), its)
throwFail(t, err)
throwFail(t, AssertIs(num, len(its)))
for _, intPk := range its {
out := IntegerPk{ID: intPk.ID}
err = dORM.Read(&out)
throwFail(t, err)
throwFail(t, AssertIs(out.Value, intPk.Value))
}
num, err = dORM.InsertMulti(1, []*IntegerPk{&IntegerPk{
ID: 1, Value: "ok",
}})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestInsertAuto(t *testing.T) {
u := &User{
UserName: "autoPre",
Email: "autoPre@gmail.com",
}
id, err := dORM.Insert(u)
throwFail(t, err)
id += 100
su := &User{
ID: int(id),
UserName: "auto",
Email: "auto@gmail.com",
}
nid, err := dORM.Insert(su)
throwFail(t, err)
throwFail(t, AssertIs(nid, id))
users := []User{
{ID: int(id + 100), UserName: "auto_100"},
{ID: int(id + 110), UserName: "auto_110"},
{ID: int(id + 120), UserName: "auto_120"},
}
num, err := dORM.InsertMulti(100, users)
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
u = &User{
UserName: "auto_121",
}
nid, err = dORM.Insert(u)
throwFail(t, err)
throwFail(t, AssertIs(nid, id+120+1))
}
func TestUintPk(t *testing.T) {
name := "go"
u := &UintPk{
ID: 8,
Name: name,
}
created, pk, err := dORM.ReadOrCreate(u, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.Name, name))
nu := &UintPk{ID: 8}
created, pk, err = dORM.ReadOrCreate(nu, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.Name, name))
dORM.Delete(u)
}

View File

@ -420,4 +420,5 @@ type dbBaser interface {
ShowColumnsQuery(string) string ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
setval(dbQuerier, *modelInfo, []string) error
} }

View File

@ -28,6 +28,7 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -59,7 +60,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
commentFilename = coomentPrefix + rep.Replace(strings.Replace(pkgRealpath, AppPath, "", -1)) + ".go" commentFilename = coomentPrefix + rep.Replace(strings.Replace(pkgRealpath, AppPath, "", -1)) + ".go"
if !compareFile(pkgRealpath) { if !compareFile(pkgRealpath) {
Info(pkgRealpath + " no changed") logs.Info(pkgRealpath + " no changed")
return nil return nil
} }
genInfoList = make(map[string][]ControllerComments) genInfoList = make(map[string][]ControllerComments)
@ -132,7 +133,7 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
func genRouterCode(pkgRealpath string) { func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755) os.Mkdir(getRouterDir(pkgRealpath), 0755)
Info("generate router from comments") logs.Info("generate router from comments")
var ( var (
globalinfo string globalinfo string
sortKey []string sortKey []string

View File

@ -28,6 +28,7 @@ import (
"time" "time"
beecontext "github.com/astaxie/beego/context" beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -114,7 +115,7 @@ type controllerInfo struct {
type ControllerRegister struct { type ControllerRegister struct {
routers map[string]*Tree routers map[string]*Tree
enableFilter bool enableFilter bool
filters map[int][]*FilterRouter filters [FinishRouter + 1][]*FilterRouter
pool sync.Pool pool sync.Pool
} }
@ -122,7 +123,6 @@ type ControllerRegister struct {
func NewControllerRegister() *ControllerRegister { func NewControllerRegister() *ControllerRegister {
cr := &ControllerRegister{ cr := &ControllerRegister{
routers: make(map[string]*Tree), routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter),
} }
cr.pool.New = func() interface{} { cr.pool.New = func() interface{} {
return beecontext.NewContext() return beecontext.NewContext()
@ -408,7 +408,6 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
// InsertFilter Add a FilterFunc with pattern rule and action constant. // InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.tree = NewTree() mr.tree = NewTree()
mr.pattern = pattern mr.pattern = pattern
@ -426,9 +425,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
} }
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
p.filters[pos] = append(p.filters[pos], mr) if pos < BeforeStatic || pos > FinishRouter {
err = fmt.Errorf("can not find your filter postion")
return
}
p.enableFilter = true p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr)
return nil return nil
} }
@ -437,11 +440,11 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error
func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".") paths := strings.Split(endpoint, ".")
if len(paths) <= 1 { if len(paths) <= 1 {
Warn("urlfor endpoint must like path.controller.method") logs.Warn("urlfor endpoint must like path.controller.method")
return "" return ""
} }
if len(values)%2 != 0 { if len(values)%2 != 0 {
Warn("urlfor params must key-value pair") logs.Warn("urlfor params must key-value pair")
return "" return ""
} }
params := make(map[string]string) params := make(map[string]string)
@ -577,10 +580,8 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return false, "" return false, ""
} }
func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
if p.enableFilter { for _, filterR := range p.filters[pos] {
if l, ok := p.filters[pos]; ok {
for _, filterR := range l {
if filterR.returnOnOutput && context.ResponseWriter.Started { if filterR.returnOnOutput && context.ResponseWriter.Started {
return true return true
} }
@ -591,8 +592,6 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, ur
return true return true
} }
} }
}
}
return false return false
} }
@ -617,11 +616,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Output.Header("Server", BConfig.ServerName) context.Output.Header("Server", BConfig.ServerName)
} }
var urlPath string var urlPath = r.URL.Path
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path) urlPath = strings.ToLower(urlPath)
} else {
urlPath = r.URL.Path
} }
// filter wrong http method // filter wrong http method
@ -631,11 +629,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// filter for static file // filter for static file
if p.execFilter(context, BeforeStatic, urlPath) { if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) {
goto Admin goto Admin
} }
serverStaticRouter(context) serverStaticRouter(context)
if context.ResponseWriter.Started { if context.ResponseWriter.Started {
findRouter = true findRouter = true
goto Admin goto Admin
@ -653,9 +652,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
var err error var err error
context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil { if err != nil {
Error(err) logs.Error(err)
exception("503", context) exception("503", context)
return goto Admin
} }
defer func() { defer func() {
if context.Input.CruSession != nil { if context.Input.CruSession != nil {
@ -663,8 +662,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
}() }()
} }
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin goto Admin
} }
@ -693,7 +691,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if findRouter { if findRouter {
//execute middleware filters //execute middleware filters
if p.execFilter(context, BeforeExec, urlPath) { if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
goto Admin goto Admin
} }
isRunnable := false isRunnable := false
@ -783,7 +781,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !context.ResponseWriter.Started && context.Output.Status == 0 { if !context.ResponseWriter.Started && context.Output.Status == 0 {
if BConfig.WebConfig.AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
panic(err) logs.Error(err)
} }
} }
} }
@ -794,17 +792,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
//execute middleware filters //execute middleware filters
if p.execFilter(context, AfterExec, urlPath) { if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
goto Admin goto Admin
} }
} }
if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) {
p.execFilter(context, FinishRouter, urlPath) goto Admin
}
Admin: Admin:
timeDur := time.Since(startTime)
//admin module record QPS //admin module record QPS
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime)
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
if runRouter != nil { if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
@ -815,6 +814,7 @@ Admin:
} }
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
timeDur := time.Since(startTime)
var devInfo string var devInfo string
if findRouter { if findRouter {
if routerInfo != nil { if routerInfo != nil {
@ -826,7 +826,7 @@ Admin:
devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch") devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
} }
if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) { if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
Debug(devInfo) logs.Debug(devInfo)
} }
} }
@ -843,7 +843,7 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
} }
if !BConfig.RecoverPanic { if !BConfig.RecoverPanic {
panic(err) panic(err)
} else { }
if BConfig.EnableErrorsShow { if BConfig.EnableErrorsShow {
if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
exception(fmt.Sprint(err), context) exception(fmt.Sprint(err), context)
@ -851,21 +851,20 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
} }
} }
var stack string var stack string
Critical("the request url is ", context.Input.URL()) logs.Critical("the request url is ", context.Input.URL())
Critical("Handler crashed with error", err) logs.Critical("Handler crashed with error", err)
for i := 1; ; i++ { for i := 1; ; i++ {
_, file, line, ok := runtime.Caller(i) _, file, line, ok := runtime.Caller(i)
if !ok { if !ok {
break break
} }
Critical(fmt.Sprintf("%s:%d", file, line)) logs.Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
} }
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
showErr(err, context, stack) showErr(err, context, stack)
} }
} }
}
} }
func toURL(params map[string]string) string { func toURL(params map[string]string) string {

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
type TestController struct { type TestController struct {
@ -94,7 +95,7 @@ func TestUrlFor(t *testing.T) {
handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param") handler.Add("/person/:last/:first", &TestController{}, "*:Param")
if a := handler.URLFor("TestController.List"); a != "/api/list" { if a := handler.URLFor("TestController.List"); a != "/api/list" {
Info(a) logs.Info(a)
t.Errorf("TestController.List must equal to /api/list") t.Errorf("TestController.List must equal to /api/list")
} }
if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
@ -120,24 +121,24 @@ func TestUrlFor2(t *testing.T) {
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
Info(handler.URLFor("TestController.GetURL")) logs.Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit") t.Errorf("TestController.List must equal to /v1/astaxie/edit")
} }
if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" { "/v1/za/cms_12_123.html" {
Info(handler.URLFor("TestController.List")) logs.Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
} }
if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" { "/v1/za_cms/ttt_12_123.html" {
Info(handler.URLFor("TestController.Param")) logs.Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
} }
if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") != ":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" { "/1111/11/aaaa/aaaa" {
Info(handler.URLFor("TestController.Get")) logs.Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
} }
} }

View File

@ -115,7 +115,6 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
} }
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid) b, time.Now().Unix(), st.sid)
} }
// Provider mysql session provider // Provider mysql session provider

View File

@ -16,7 +16,6 @@ package session
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -82,14 +81,17 @@ func (fs *FileSessionStore) SessionID() string {
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
if err != nil { if err != nil {
SLogger.Println(err)
return return
} }
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File var f *os.File
if err == nil { if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
SLogger.Println(err)
} else if os.IsNotExist(err) { } else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
SLogger.Println(err)
} else { } else {
return return
} }
@ -123,7 +125,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File var f *os.File
@ -191,7 +193,7 @@ func (fp *FileProvider) SessionAll() int {
return a.visit(path, f, err) return a.visit(path, f, err)
}) })
if err != nil { if err != nil {
fmt.Printf("filepath.Walk() returned %v\n", err) SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0 return 0
} }
return a.total return a.total
@ -205,11 +207,11 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil { if err != nil {
println(err.Error()) SLogger.Println(err.Error())
} }
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File var newf *os.File

View File

@ -31,9 +31,14 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"log"
"net/http" "net/http"
"net/textproto"
"net/url" "net/url"
"os"
"time" "time"
) )
@ -61,6 +66,9 @@ type Provider interface {
var provides = make(map[string]Provider) var provides = make(map[string]Provider)
// SLogger a helpful variable to log information about session
var SLogger = NewSessionLog(os.Stderr)
// Register makes a session provide available by the provided name. // Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
// it panics. // it panics.
@ -84,6 +92,9 @@ type managerConfig struct {
ProviderConfig string `json:"providerConfig"` ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"` Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"` SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"`
SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"`
EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"`
} }
// Manager contains Provider and its configuration. // Manager contains Provider and its configuration.
@ -118,6 +129,19 @@ func NewManager(provideName, config string) (*Manager, error) {
if cf.Maxlifetime == 0 { if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime cf.Maxlifetime = cf.Gclifetime
} }
if cf.EnableSidInHttpHeader {
if cf.SessionNameInHttpHeader == "" {
panic(errors.New("SessionNameInHttpHeader is empty"))
}
strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader)
if cf.SessionNameInHttpHeader != strMimeHeader {
strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -143,12 +167,24 @@ func NewManager(provideName, config string) (*Manager, error) {
func (manager *Manager) getSid(r *http.Request) (string, error) { func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName) cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 { if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
var sid string
if manager.config.EnableSidInUrlQuery {
errs := r.ParseForm() errs := r.ParseForm()
if errs != nil { if errs != nil {
return "", errs return "", errs
} }
sid := r.FormValue(manager.config.CookieName) sid = r.FormValue(manager.config.CookieName)
}
// if not found in Cookie / param, then read it from request headers
if manager.config.EnableSidInHttpHeader && sid == "" {
sids, isFound := r.Header[manager.config.SessionNameInHttpHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
}
return sid, nil return sid, nil
} }
@ -192,11 +228,21 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} }
r.AddCookie(cookie) r.AddCookie(cookie)
if manager.config.EnableSidInHttpHeader {
r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
}
return return
} }
// SessionDestroy Destroy session by its id in http request cookie. // SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
if manager.config.EnableSidInHttpHeader {
r.Header.Del(manager.config.SessionNameInHttpHeader)
w.Header().Del(manager.config.SessionNameInHttpHeader)
}
cookie, err := r.Cookie(manager.config.CookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
return return
@ -261,6 +307,12 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
} }
r.AddCookie(cookie) r.AddCookie(cookie)
if manager.config.EnableSidInHttpHeader {
r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
}
return return
} }
@ -296,3 +348,15 @@ func (manager *Manager) isSecure(req *http.Request) bool {
} }
return true return true
} }
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// NewSessionLog set io.Writer to create a Logger for session.
func NewSessionLog(out io.Writer) *Log {
sl := new(Log)
sl.Logger = log.New(out, "[SESSION]", 1e9)
return sl
}

View File

@ -27,6 +27,7 @@ import (
"time" "time"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
var errNotStaticRequest = errors.New("request not a static file request") var errNotStaticRequest = errors.New("request not a static file request")
@ -48,14 +49,19 @@ func serverStaticRouter(ctx *context.Context) {
if filePath == "" || fileInfo == nil { if filePath == "" || fileInfo == nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn("Can't find/open the file:", filePath, err) logs.Warn("Can't find/open the file:", filePath, err)
} }
http.NotFound(ctx.ResponseWriter, ctx.Request) http.NotFound(ctx.ResponseWriter, ctx.Request)
return return
} }
if fileInfo.IsDir() { if fileInfo.IsDir() {
requestURL := ctx.Input.URL()
if requestURL[len(requestURL)-1] != '/' {
ctx.Redirect(302, requestURL+"/")
} else {
//serveFile will list dir //serveFile will list dir
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
}
return return
} }
@ -67,7 +73,7 @@ func serverStaticRouter(ctx *context.Context) {
b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil { if err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
Warn("Can't compress the file:", filePath, err) logs.Warn("Can't compress the file:", filePath, err)
} }
http.NotFound(ctx.ResponseWriter, ctx.Request) http.NotFound(ctx.ResponseWriter, ctx.Request)
return return

View File

@ -1,160 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package swagger struct definition
package swagger
// SwaggerVersion show the current swagger version
const SwaggerVersion = "1.2"
// ResourceListing list the resource
type ResourceListing struct {
APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2
// BasePath string `json:"basePath"` obsolete in 1.1
APIs []APIRef `json:"apis"`
Info Information `json:"info"`
}
// APIRef description the api path and description
type APIRef struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
}
// Information show the API Information
type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Contact string `json:"contact,omitempty"`
TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
License string `json:"license,omitempty"`
LicenseURL string `json:"licenseUrl,omitempty"`
}
// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
type APIDeclaration struct {
APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"`
BasePath string `json:"basePath"`
ResourcePath string `json:"resourcePath"` // must start with /
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
APIs []API `json:"apis,omitempty"`
Models map[string]Model `json:"models,omitempty"`
}
// API show tha API struct
type API struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
Operations []Operation `json:"operations,omitempty"`
}
// Operation desc the Operation
type Operation struct {
HTTPMethod string `json:"httpMethod"`
Nickname string `json:"nickname"`
Type string `json:"type"` // in 1.1 = DataType
// ResponseClass string `json:"responseClass"` obsolete in 1.2
Summary string `json:"summary,omitempty"`
Notes string `json:"notes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty"`
ResponseMessages []ResponseMessage `json:"responseMessages,omitempty"` // optional
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Authorizations []Authorization `json:"authorizations,omitempty"`
Protocols []Protocol `json:"protocols,omitempty"`
}
// Protocol support which Protocol
type Protocol struct {
}
// ResponseMessage Show the
type ResponseMessage struct {
Code int `json:"code"`
Message string `json:"message"`
ResponseModel string `json:"responseModel"`
}
// Parameter desc the request parameters
type Parameter struct {
ParamType string `json:"paramType"` // path,query,body,header,form
Name string `json:"name"`
Description string `json:"description"`
DataType string `json:"dataType"` // 1.2 needed?
Type string `json:"type"` // integer
Format string `json:"format"` // int64
AllowMultiple bool `json:"allowMultiple"`
Required bool `json:"required"`
Minimum int `json:"minimum"`
Maximum int `json:"maximum"`
}
// ErrorResponse desc response
type ErrorResponse struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
// Model define the data model
type Model struct {
ID string `json:"id"`
Required []string `json:"required,omitempty"`
Properties map[string]ModelProperty `json:"properties"`
}
// ModelProperty define the properties
type ModelProperty struct {
Type string `json:"type"`
Description string `json:"description"`
Items map[string]string `json:"items,omitempty"`
Format string `json:"format"`
}
// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations
type Authorization struct {
LocalOAuth OAuth `json:"local-oauth"`
APIKey APIKey `json:"apiKey"`
}
// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations
type OAuth struct {
Type string `json:"type"` // e.g. oauth2
Scopes []string `json:"scopes"` // e.g. PUBLIC
GrantTypes map[string]GrantType `json:"grantTypes"`
}
// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations
type GrantType struct {
LoginEndpoint Endpoint `json:"loginEndpoint"`
TokenName string `json:"tokenName"` // e.g. access_code
TokenRequestEndpoint Endpoint `json:"tokenRequestEndpoint"`
TokenEndpoint Endpoint `json:"tokenEndpoint"`
}
// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations
type Endpoint struct {
URL string `json:"url"`
ClientIDName string `json:"clientIdName"`
ClientSecretName string `json:"clientSecretName"`
TokenName string `json:"tokenName"`
}
// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations
type APIKey struct {
Type string `json:"type"` // e.g. apiKey
PassAs string `json:"passAs"` // e.g. header
}

155
swagger/swagger.go Normal file
View File

@ -0,0 +1,155 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Swagger™ is a project used to describe and document RESTful APIs.
//
// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools.
// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software.
// Package swagger struct definition
package swagger
// Swagger list the resource
type Swagger struct {
SwaggerVersion string `json:"swagger,omitempty"`
Infos Information `json:"info"`
Host string `json:"host,omitempty"`
BasePath string `json:"basePath,omitempty"`
Schemes []string `json:"schemes,omitempty"`
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Paths map[string]Item `json:"paths"`
Definitions map[string]Schema `json:"definitions,omitempty"`
SecurityDefinitions map[string]Scurity `json:"securityDefinitions,omitempty"`
Security map[string][]string `json:"security,omitempty"`
Tags []Tag `json:"tags,omitempty"`
ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
}
// Information Provides metadata about the API. The metadata can be used by the clients if needed.
type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Version string `json:"version,omitempty"`
TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
Contact Contact `json:"contact,omitempty"`
License License `json:"license,omitempty"`
}
// Contact information for the exposed API.
type Contact struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
EMail string `json:"email,omitempty"`
}
// License information for the exposed API.
type License struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
}
// Item Describes the operations available on a single path.
type Item struct {
Ref string `json:"$ref,omitempty"`
Get *Operation `json:"get,omitempty"`
Put *Operation `json:"put,omitempty"`
Post *Operation `json:"post,omitempty"`
Delete *Operation `json:"delete,omitempty"`
Options *Operation `json:"options,omitempty"`
Head *Operation `json:"head,omitempty"`
Patch *Operation `json:"patch,omitempty"`
}
// Operation Describes a single API operation on a path.
type Operation struct {
Tags []string `json:"tags,omitempty"`
Summary string `json:"summary,omitempty"`
Description string `json:"description,omitempty"`
OperationID string `json:"operationId,omitempty"`
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
Schemes []string `json:"schemes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty"`
Responses map[string]Response `json:"responses,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
}
// Parameter Describes a single operation parameter.
type Parameter struct {
In string `json:"in,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
Schema Schema `json:"schema,omitempty"`
Type string `json:"type,omitempty"`
Format string `json:"format,omitempty"`
}
// Schema Object allows the definition of input and output data types.
type Schema struct {
Ref string `json:"$ref,omitempty"`
Title string `json:"title,omitempty"`
Format string `json:"format,omitempty"`
Description string `json:"description,omitempty"`
Required []string `json:"required,omitempty"`
Type string `json:"type,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty"`
}
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
type Propertie struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Default string `json:"default,omitempty"`
Type string `json:"type,omitempty"`
Example string `json:"example,omitempty"`
Required []string `json:"required,omitempty"`
Format string `json:"format,omitempty"`
ReadOnly bool `json:"readOnly,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty"`
}
// Response as they are returned from executing this operation.
type Response struct {
Description string `json:"description,omitempty"`
Schema Schema `json:"schema,omitempty"`
Ref string `json:"$ref,omitempty"`
}
// Scurity Allows the definition of a security scheme that can be used by the operations
type Scurity struct {
Type string `json:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
In string `json:"in,omitempty"` // Valid values are "query" or "header".
Flow string `json:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
AuthorizationURL string `json:"authorizationUrl,omitempty"`
TokenURL string `json:"tokenUrl,omitempty"`
Scopes map[string]string `json:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
}
// Tag Allows adding meta data to a single tag that is used by the Operation Object
type Tag struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
}
// ExternalDocs include Additional external documentation
type ExternalDocs struct {
Description string `json:"description,omitempty"`
URL string `json:"url,omitempty"`
}

View File

@ -26,19 +26,25 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
var ( var (
beegoTplFuncMap = make(template.FuncMap) beegoTplFuncMap = make(template.FuncMap)
// beeTemplates caching map and supported template file extensions. // beeTemplates caching map and supported template file extensions.
beeTemplates = make(map[string]*template.Template) beeTemplates = make(map[string]TemplateRenderer)
templatesLock sync.RWMutex templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build // beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"} beeTemplateExt = []string{"tpl", "html"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templateHandler{}
) )
func executeTemplate(wr io.Writer, name string, data interface{}) error { // ExecuteTemplate applies the template with name to the specified data object,
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
templatesLock.RLock() templatesLock.RLock()
defer templatesLock.RUnlock() defer templatesLock.RUnlock()
@ -46,7 +52,7 @@ func executeTemplate(wr io.Writer, name string, data interface{}) error {
if t, ok := beeTemplates[name]; ok { if t, ok := beeTemplates[name]; ok {
err := t.ExecuteTemplate(wr, name, data) err := t.ExecuteTemplate(wr, name, data)
if err != nil { if err != nil {
Trace("template Execute err:", err) logs.Trace("template Execute err:", err)
} }
return err return err
} }
@ -88,6 +94,10 @@ func AddFuncMap(key string, fn interface{}) error {
return nil return nil
} }
type templateHandler func(root, path string, funcs template.FuncMap) (TemplateRenderer, error)
type TemplateRenderer interface {
ExecuteTemplate(wr io.Writer, name string, data interface{}) error
}
type templateFile struct { type templateFile struct {
root string root string
files map[string][]string files map[string][]string
@ -156,13 +166,22 @@ func BuildTemplate(dir string, files ...string) error {
fmt.Printf("filepath.Walk() returned %v\n", err) fmt.Printf("filepath.Walk() returned %v\n", err)
return err return err
} }
buildAllFiles := len(files) == 0
for _, v := range self.files { for _, v := range self.files {
for _, file := range v { for _, file := range v {
if len(files) == 0 || utils.InSlice(file, files) { if buildAllFiles || utils.InSlice(file, files) {
templatesLock.Lock() templatesLock.Lock()
t, err := getTemplate(self.root, file, v...) ext := filepath.Ext(file)
var t TemplateRenderer
if len(ext) == 0 {
t, err = getTemplate(self.root, file, v...)
} else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
t, err = fn(self.root, file, beegoTplFuncMap)
} else {
t, err = getTemplate(self.root, file, v...)
}
if err != nil { if err != nil {
Trace("parse template err:", file, err) logs.Trace("parse template err:", file, err)
} else { } else {
beeTemplates[file] = t beeTemplates[file] = t
} }
@ -240,7 +259,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil { if err != nil {
Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 { } else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, subMods1, others...)
} }
@ -261,7 +280,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil { if err != nil {
Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if subMods1 != nil && len(subMods1) > 0 { } else if subMods1 != nil && len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, subMods1, others...)
} }
@ -305,3 +324,9 @@ func DelStaticPath(url string) *App {
delete(BConfig.WebConfig.StaticDir, url) delete(BConfig.WebConfig.StaticDir, url)
return BeeApp return BeeApp
} }
func AddTemplateEngine(extension string, fn templateHandler) *App {
AddTemplateExt(extension)
beeTemplateEngines[extension] = fn
return BeeApp
}

View File

@ -389,6 +389,10 @@ func dayMatches(s *Schedule, t time.Time) bool {
// StartTask start all tasks // StartTask start all tasks
func StartTask() { func StartTask() {
if isstart {
//If already started no need to start another goroutine.
return
}
isstart = true isstart = true
go run() go run()
} }
@ -432,8 +436,11 @@ func run() {
// StopTask stop all tasks // StopTask stop all tasks
func StopTask() { func StopTask() {
if isstart {
isstart = false isstart = false
stop <- true stop <- true
}
} }
// AddTask add task with name // AddTask add task with name

View File

@ -69,6 +69,7 @@ import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -139,7 +140,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
if err := c.store.Put(key, chars, c.Expiration); err != nil { if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500) ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error") ctx.WriteString("captcha reload error")
beego.Error("Reload Create Captcha Error:", err) logs.Error("Reload Create Captcha Error:", err)
return return
} }
} else { } else {
@ -154,7 +155,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
img := NewImage(chars, c.StdWidth, c.StdHeight) img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
beego.Error("Write Captcha Image Error:", err) logs.Error("Write Captcha Image Error:", err)
} }
} }
@ -162,7 +163,7 @@ func (c *Captcha) Handler(ctx *context.Context) {
func (c *Captcha) CreateCaptchaHTML() template.HTML { func (c *Captcha) CreateCaptchaHTML() template.HTML {
value, err := c.CreateCaptcha() value, err := c.CreateCaptcha()
if err != nil { if err != nil {
beego.Error("Create Captcha Error:", err) logs.Error("Create Captcha Error:", err)
return "" return ""
} }

View File

@ -18,7 +18,7 @@ import (
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
) )
// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data["paginator"]. // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator").
func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) {
paginator = NewPaginator(context.Request, per, nums) paginator = NewPaginator(context.Request, per, nums)
context.Input.SetData("paginator", &paginator) context.Input.SetData("paginator", &paginator)

View File

@ -20,29 +20,25 @@ import (
"time" "time"
) )
var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`)
// RandomCreateBytes generate random []byte by specify chars. // RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte { func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" if len(alphabets) == 0 {
alphabets = alphaNum
}
var bytes = make([]byte, n) var bytes = make([]byte, n)
var randby bool var randBy bool
if num, err := rand.Read(bytes); num != n || err != nil { if num, err := rand.Read(bytes); num != n || err != nil {
r.Seed(time.Now().UnixNano()) r.Seed(time.Now().UnixNano())
randby = true randBy = true
} }
for i, b := range bytes { for i, b := range bytes {
if len(alphabets) == 0 { if randBy {
if randby {
bytes[i] = alphanum[r.Intn(len(alphanum))]
} else {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
} else {
if randby {
bytes[i] = alphabets[r.Intn(len(alphabets))] bytes[i] = alphabets[r.Intn(len(alphabets))]
} else { } else {
bytes[i] = alphabets[b%byte(len(alphabets))] bytes[i] = alphabets[b%byte(len(alphabets))]
} }
} }
}
return bytes return bytes
} }

View File

@ -1,4 +1,4 @@
// Copyright 2014 beego Author. All Rights Reserved. // Copyright 2016 beego Author. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -12,28 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package beego package utils
import ( import "testing"
"github.com/astaxie/beego/context"
)
// GlobalDocAPI store the swagger api documents func TestRand_01(t *testing.T) {
var GlobalDocAPI = make(map[string]interface{}) bs0 := RandomCreateBytes(16)
bs1 := RandomCreateBytes(16)
func serverDocs(ctx *context.Context) { t.Log(string(bs0), string(bs1))
var obj interface{} if string(bs0) == string(bs1) {
if splat := ctx.Input.Param(":splat"); splat == "" { t.FailNow()
obj = GlobalDocAPI["Root"]
} else {
if v, ok := GlobalDocAPI[splat]; ok {
obj = v
} }
bs0 = RandomCreateBytes(4, []byte(`a`)...)
if string(bs0) != "aaaa" {
t.FailNow()
} }
if obj != nil {
ctx.Output.Header("Access-Control-Allow-Origin", "*")
ctx.Output.JSON(obj, false, false)
return
}
ctx.Output.SetStatus(404)
} }