diff --git a/.travis.yml b/.travis.yml index 133ec3cc..31a76557 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,8 @@ language: go go: - - 1.7.5 - - 1.8.5 - - 1.9.2 + - "1.9.7" + - "1.10.3" services: - redis-server - mysql @@ -22,10 +21,11 @@ install: - go get github.com/go-sql-driver/mysql - go get github.com/mattn/go-sqlite3 - go get github.com/bradfitz/gomemcache/memcache - - go get github.com/garyburd/redigo/redis + - go get github.com/gomodule/redigo/redis - go get github.com/beego/x2j - go get github.com/couchbase/go-couchbase - go get github.com/beego/goyaml2 + - go get gopkg.in/yaml.v2 - go get github.com/belogik/goes - go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/ledis @@ -38,13 +38,14 @@ install: - go get -u github.com/mdempsky/unconvert - go get -u github.com/gordonklaus/ineffassign - go get -u github.com/golang/lint/golint + - go get -u github.com/go-redis/redis before_script: - psql --version - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - - sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi" - - sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi" + - sh -c "go get github.com/golang/lint/golint; golint ./...;" + - sh -c "go tool vet ." - mkdir -p res/var - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d after_script: @@ -58,4 +59,4 @@ script: - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - golint ./... addons: - postgresql: "9.4" + postgresql: "9.6" diff --git a/admin.go b/admin.go index 0688dcbc..73d4f9f2 100644 --- a/admin.go +++ b/admin.go @@ -76,6 +76,18 @@ func adminIndex(rw http.ResponseWriter, r *http.Request) { func qpsIndex(rw http.ResponseWriter, r *http.Request) { data := make(map[interface{}]interface{}) data["Content"] = toolbox.StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(map[string]interface{}); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + execTpl(rw, data, qpsTpl, defaultScriptsTpl) } diff --git a/admin_test.go b/admin_test.go index a99da6fc..03a3c044 100644 --- a/admin_test.go +++ b/admin_test.go @@ -67,6 +67,7 @@ func oldMap() map[string]interface{} { m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum m["BConfig.Log.Outputs"] = BConfig.Log.Outputs diff --git a/app.go b/app.go index 0c07117a..32ff159d 100644 --- a/app.go +++ b/app.go @@ -24,12 +24,13 @@ import ( "net/http/fcgi" "os" "path" - "time" "strings" + "time" "github.com/astaxie/beego/grace" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" + "golang.org/x/crypto/acme/autocert" ) var ( @@ -101,7 +102,7 @@ func (app *App) Run(mws ...MiddleWare) { } app.Server.Handler = app.Handlers - for i:=len(mws)-1;i>=0;i-- { + for i := len(mws) - 1; i >= 0; i-- { if mws[i] == nil { continue } @@ -117,7 +118,7 @@ func (app *App) Run(mws ...MiddleWare) { app.Server.Addr = httpsAddr if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { go func() { - time.Sleep(20 * time.Microsecond) + time.Sleep(1000 * time.Microsecond) if BConfig.Listen.HTTPSPort != 0 { httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) app.Server.Addr = httpsAddr @@ -126,13 +127,21 @@ func (app *App) Run(mws ...MiddleWare) { server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout if BConfig.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) endRunning <- true } } else { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) @@ -163,15 +172,23 @@ func (app *App) Run(mws ...MiddleWare) { // run normal mode if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { go func() { - time.Sleep(20 * time.Microsecond) + time.Sleep(1000 * time.Microsecond) if BConfig.Listen.HTTPSPort != 0 { app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) } else if BConfig.Listen.EnableHTTP { - BeeLogger.Info("Start https server error, confict with http.Please reset https port") + BeeLogger.Info("Start https server error, conflict with http. Please reset https port") return } logs.Info("https server Running on https://%s", app.Server.Addr) - if BConfig.Listen.EnableMutualHTTPS { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } else if BConfig.Listen.EnableMutualHTTPS { pool := x509.NewCertPool() data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) if err != nil { @@ -190,6 +207,7 @@ func (app *App) Run(mws ...MiddleWare) { endRunning <- true } }() + } if BConfig.Listen.EnableHTTP { go func() { diff --git a/beego.go b/beego.go index fdbdc798..407fcc4c 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.9.2" + VERSION = "1.10.0" // DEV is for develop DEV = "dev" @@ -62,6 +62,8 @@ func Run(params ...string) { if len(strs) > 1 && strs[1] != "" { BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) } + + BConfig.Listen.Domains = params } BeeApp.Run() @@ -74,6 +76,7 @@ func RunWithMiddleWares(addr string, mws ...MiddleWare) { strs := strings.Split(addr, ":") if len(strs) > 0 && strs[0] != "" { BConfig.Listen.HTTPAddr = strs[0] + BConfig.Listen.Domains = []string{strs[0]} } if len(strs) > 1 && strs[1] != "" { BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) diff --git a/cache/README.md b/cache/README.md index 957790e7..b467760a 100644 --- a/cache/README.md +++ b/cache/README.md @@ -52,7 +52,7 @@ Configure like this: ## Redis adapter -Redis adapter use the [redigo](http://github.com/garyburd/redigo) client. +Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. Configure like this: diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 1da22480..55fea25e 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -14,9 +14,9 @@ // Package redis for cache provider // -// depend on github.com/garyburd/redigo/redis +// depend on github.com/gomodule/redigo/redis // -// go install github.com/garyburd/redigo/redis +// go install github.com/gomodule/redigo/redis // // Usage: // import( @@ -36,7 +36,7 @@ import ( "strconv" "time" - "github.com/garyburd/redigo/redis" + "github.com/gomodule/redigo/redis" "github.com/astaxie/beego/cache" ) @@ -53,6 +53,7 @@ type Cache struct { dbNum int key string password string + maxIdle int } // NewRedisCache create new redis cache with default collection name. @@ -169,10 +170,14 @@ func (rc *Cache) StartAndGC(config string) error { if _, ok := cf["password"]; !ok { cf["password"] = "" } + if _, ok := cf["maxIdle"]; !ok { + cf["maxIdle"] = "3" + } rc.key = cf["key"] rc.conninfo = cf["conn"] rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) rc.password = cf["password"] + rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) rc.connectInit() @@ -206,7 +211,7 @@ func (rc *Cache) connectInit() { } // initialize a new pool rc.p = &redis.Pool{ - MaxIdle: 3, + MaxIdle: rc.maxIdle, IdleTimeout: 180 * time.Second, Dial: dialFunc, } diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 6b81da4d..56877f6b 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -19,7 +19,7 @@ import ( "time" "github.com/astaxie/beego/cache" - "github.com/garyburd/redigo/redis" + "github.com/gomodule/redigo/redis" ) func TestRedisCache(t *testing.T) { diff --git a/config.go b/config.go index eeeac8ee..7969dcea 100644 --- a/config.go +++ b/config.go @@ -55,6 +55,9 @@ type Listen struct { EnableHTTP bool HTTPAddr string HTTPPort int + AutoTLS bool + Domains []string + TLSCacheDir string EnableHTTPS bool EnableMutualHTTPS bool HTTPSAddr string @@ -98,14 +101,15 @@ type SessionConfig struct { SessionAutoSetCookie bool SessionDomain string SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params } // LogConfig holds Log related config type LogConfig struct { AccessLogs bool + EnableStaticLogs bool //log static files requests default: false AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string FileLineNum bool Outputs map[string]string // Store Adaptor : config @@ -138,8 +142,8 @@ func init() { panic(err) } var filename = "app.conf" - if os.Getenv("BEEGO_MODE") != "" { - filename = os.Getenv("BEEGO_MODE") + ".app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" } appConfigPath = filepath.Join(workPath, "conf", filename) if !utils.FileExists(appConfigPath) { @@ -182,13 +186,18 @@ func recoverPanic(ctx *context.Context) { if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { showErr(err, ctx, stack) } + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } } } func newBConfig() *Config { return &Config{ AppName: "beego", - RunMode: DEV, + RunMode: PROD, RouterCaseSensitive: true, ServerName: "beegoServer:" + VERSION, RecoverPanic: true, @@ -203,6 +212,9 @@ func newBConfig() *Config { ServerTimeOut: 0, ListenTCP4: false, EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", HTTPAddr: "", HTTPPort: 8080, EnableHTTPS: false, @@ -240,13 +252,14 @@ func newBConfig() *Config { SessionCookieLifeTime: 0, //set cookie default is the browser life SessionAutoSetCookie: true, SessionDomain: "", - SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers SessionNameInHTTPHeader: "Beegosessionid", - SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params }, }, Log: LogConfig{ AccessLogs: false, + EnableStaticLogs: false, AccessLogsFormat: "APACHE_FORMAT", FileLineNum: true, Outputs: map[string]string{"console": ""}, diff --git a/config/config.go b/config/config.go index c620504a..bfd79e85 100644 --- a/config/config.go +++ b/config/config.go @@ -150,12 +150,12 @@ func ExpandValueEnv(value string) (realValue string) { } key := "" - defalutV := "" + defaultV := "" // value start with "${" for i := 2; i < vLen; i++ { if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { key = value[2:i] - defalutV = value[i+2 : vLen-1] // other string is default value. + defaultV = value[i+2 : vLen-1] // other string is default value. break } else if value[i] == '}' { key = value[2:i] @@ -165,7 +165,7 @@ func ExpandValueEnv(value string) (realValue string) { realValue = os.Getenv(key) if realValue == "" { - realValue = defalutV + realValue = defaultV } return diff --git a/config/fake.go b/config/fake.go index f5144598..d21ab820 100644 --- a/config/fake.go +++ b/config/fake.go @@ -126,7 +126,7 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error { var _ Configer = new(fakeConfigContainer) -// NewFakeConfig return a fake Congiger +// NewFakeConfig return a fake Configer func NewFakeConfig() Configer { return &fakeConfigContainer{ data: make(map[string]string), diff --git a/config/ini.go b/config/ini.go index a681bc1b..1376d933 100644 --- a/config/ini.go +++ b/config/ini.go @@ -215,7 +215,7 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) { } // DefaultBool returns the boolean value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { v, err := c.Bool(key) if err != nil { @@ -230,7 +230,7 @@ func (c *IniConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -245,7 +245,7 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -260,7 +260,7 @@ func (c *IniConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -275,7 +275,7 @@ func (c *IniConfigContainer) String(key string) string { } // DefaultString returns the string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -295,7 +295,7 @@ func (c *IniConfigContainer) Strings(key string) []string { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -314,7 +314,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro // SaveConfigFile save the config into file. // -// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation. +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) diff --git a/config/json.go b/config/json.go index a0d93210..74c18c9c 100644 --- a/config/json.go +++ b/config/json.go @@ -101,7 +101,7 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { if v, err := c.Int(key); err == nil { return v @@ -122,7 +122,7 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { if v, err := c.Int64(key); err == nil { return v @@ -143,7 +143,7 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { if v, err := c.Float(key); err == nil { return v @@ -163,7 +163,7 @@ func (c *JSONConfigContainer) String(key string) string { } // DefaultString returns the string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { // TODO FIXME should not use "" to replace non existence if v := c.String(key); v != "" { @@ -182,7 +182,7 @@ func (c *JSONConfigContainer) Strings(key string) []string { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { if v := c.Strings(key); v != nil { return v diff --git a/config/json_test.go b/config/json_test.go index 24ff9644..16f42409 100644 --- a/config/json_test.go +++ b/config/json_test.go @@ -216,7 +216,7 @@ func TestJson(t *testing.T) { t.Error("unknown keys should return an error when expecting a Bool") } - if !jsonconf.DefaultBool("unknow", true) { + if !jsonconf.DefaultBool("unknown", true) { t.Error("unknown keys with default value wrong") } } diff --git a/config/xml/xml.go b/config/xml/xml.go index b82bf403..494242d3 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -102,7 +102,7 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -117,7 +117,7 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -133,7 +133,7 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -151,7 +151,7 @@ func (c *ConfigContainer) String(key string) string { } // DefaultString returns the string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -170,7 +170,7 @@ func (c *ConfigContainer) Strings(key string) []string { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index 51fe44d3..881737e3 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -119,7 +119,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { // ConfigContainer A Config represents the yaml configuration. type ConfigContainer struct { data map[string]interface{} - sync.Mutex + sync.RWMutex } // Bool returns the boolean value for a given key. @@ -154,7 +154,7 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { v, err := c.Int(key) if err != nil { @@ -174,7 +174,7 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { v, err := c.Int64(key) if err != nil { @@ -198,7 +198,7 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { v, err := c.Float(key) if err != nil { @@ -218,7 +218,7 @@ func (c *ConfigContainer) String(key string) string { } // DefaultString returns the string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultString(key string, defaultval string) string { v := c.String(key) if v == "" { @@ -237,7 +237,7 @@ func (c *ConfigContainer) Strings(key string) []string { } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaltval +// if err != nil return defaultval func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { v := c.Strings(key) if v == nil { @@ -285,9 +285,25 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { if len(key) == 0 { return nil, errors.New("key is empty") } + c.RLock() + defer c.RUnlock() - if v, ok := c.data[key]; ok { - return v, nil + keys := strings.Split(key, ".") + tmpData := c.data + for _, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + } + default: + { + return v, nil + } + + } + } } return nil, fmt.Errorf("not exist key %q", key) } diff --git a/config_test.go b/config_test.go index c1973f7b..379fe1c6 100644 --- a/config_test.go +++ b/config_test.go @@ -75,7 +75,7 @@ func TestAssignConfig_02(t *testing.T) { jcf := &config.JSONConfig{} bs, _ = json.Marshal(configMap) - ac, _ := jcf.ParseData([]byte(bs)) + ac, _ := jcf.ParseData(bs) for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { assignSingleConfig(i, ac) diff --git a/context/input.go b/context/input.go index 168c709a..81952158 100644 --- a/context/input.go +++ b/context/input.go @@ -37,6 +37,7 @@ var ( acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) maxParam = 50 ) @@ -203,6 +204,10 @@ func (input *BeegoInput) AcceptsXML() bool { func (input *BeegoInput) AcceptsJSON() bool { return acceptsJSONRegex.MatchString(input.Header("Accept")) } +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return acceptsYAMLRegex.MatchString(input.Header("Accept")) +} // IP returns request client ip. // if in proxy, return first proxy id. diff --git a/context/output.go b/context/output.go index 61ce8cc7..3cc6044e 100644 --- a/context/output.go +++ b/context/output.go @@ -30,6 +30,7 @@ import ( "strconv" "strings" "time" + "gopkg.in/yaml.v2" ) // BeegoOutput does work for sending response header. @@ -182,8 +183,8 @@ func errorRenderer(err error) Renderer { } // JSON writes json to response body. -// if coding is true, it converts utf-8 to \u0000 type. -func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { output.Header("Content-Type", "application/json; charset=utf-8") var content []byte var err error @@ -196,12 +197,26 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) e http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) return err } - if coding { + if encoding { content = []byte(stringsToJSON(string(content))) } return output.Body(content) } + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + output.Header("Content-Type", "application/application/x-yaml; charset=utf-8") + var content []byte + var err error + content, err = yaml.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + // JSONP writes jsonp to response body. func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/javascript; charset=utf-8") @@ -260,7 +275,7 @@ func (output *BeegoOutput) Download(file string, filename ...string) { } else { fName = filepath.Base(file) } - output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName)) + output.Header("Content-Disposition", "attachment; filename="+url.PathEscape(fName)) output.Header("Content-Description", "File Transfer") output.Header("Content-Type", "application/octet-stream") output.Header("Content-Transfer-Encoding", "binary") @@ -325,13 +340,13 @@ func (output *BeegoOutput) IsForbidden() bool { } // IsNotFound returns boolean of this request is not found. -// HTTP 404 means forbidden. +// HTTP 404 means not found. func (output *BeegoOutput) IsNotFound() bool { return output.Status == 404 } // IsClientError returns boolean of this request client sends error data. -// HTTP 4xx means forbidden. +// HTTP 4xx means client error. func (output *BeegoOutput) IsClientError() bool { return output.Status >= 400 && output.Status < 500 } @@ -350,6 +365,11 @@ func stringsToJSON(str string) string { jsons.WriteRune(r) } else { jsons.WriteString("\\u") + if rint < 0x100 { + jsons.WriteString("00") + } else if rint < 0x1000 { + jsons.WriteString("0") + } jsons.WriteString(strconv.FormatInt(int64(rint), 16)) } } diff --git a/controller.go b/controller.go index c104eb2a..8be43a33 100644 --- a/controller.go +++ b/controller.go @@ -36,6 +36,7 @@ import ( const ( applicationJSON = "application/json" applicationXML = "application/xml" + applicationYAML = "application/x-yaml" textXML = "text/xml" ) @@ -272,7 +273,22 @@ func (c *Controller) viewPath() string { // Redirect sends the redirection response to url with status code. func (c *Controller) Redirect(url string, code int) { + logAccess(c.Ctx, nil, code) c.Ctx.Redirect(code, url) + panic(ErrAbort) +} + +// Set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case applicationJSON: + c.Data["json"] = data + case applicationXML, textXML: + c.Data["xml"] = data + default: + c.Data["json"] = data + } } // Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. @@ -347,6 +363,11 @@ func (c *Controller) ServeXML() { c.Ctx.Output.XML(c.Data["xml"], hasIndent) } +// ServeXML sends xml response. +func (c *Controller) ServeYAML() { + c.Ctx.Output.YAML(c.Data["yaml"]) +} + // ServeFormatted serve Xml OR Json, depending on the value of the Accept header func (c *Controller) ServeFormatted() { accept := c.Ctx.Input.Header("Accept") @@ -355,6 +376,8 @@ func (c *Controller) ServeFormatted() { c.ServeJSON() case applicationXML, textXML: c.ServeXML() + case applicationYAML: + c.ServeYAML() default: c.ServeJSON() } diff --git a/error.go b/error.go index b913db39..7a50a5b3 100644 --- a/error.go +++ b/error.go @@ -28,7 +28,7 @@ import ( ) const ( - errorTypeHandler = iota + errorTypeHandler = iota errorTypeController ) @@ -93,11 +93,6 @@ func showErr(err interface{}, ctx *context.Context, stack string) { "BeegoVersion": VERSION, "GoVersion": runtime.Version(), } - if ctx.Output.Status != 0 { - ctx.ResponseWriter.WriteHeader(ctx.Output.Status) - } else { - ctx.ResponseWriter.WriteHeader(500) - } t.Execute(ctx.ResponseWriter, data) } @@ -439,6 +434,9 @@ func exception(errCode string, ctx *context.Context) { } func executeError(err *errorInfo, ctx *context.Context, code int) { + //make sure to log the error in the access log + logAccess(ctx, nil, code) + if err.errorType == errorTypeHandler { ctx.ResponseWriter.WriteHeader(code) err.handler(ctx.ResponseWriter, ctx.Request) diff --git a/httplib/httplib.go b/httplib/httplib.go index 5d82ab33..074cf661 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -50,6 +50,7 @@ import ( "strings" "sync" "time" + "gopkg.in/yaml.v2" ) var defaultSetting = BeegoHTTPSettings{ @@ -317,6 +318,7 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { } return b } + // XMLBody adds request raw body encoding by XML. func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { @@ -330,6 +332,21 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { } return b, nil } + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := yaml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/x+yaml") + } + return b, nil +} + // JSONBody adds request raw body encoding by JSON. func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { @@ -429,12 +446,12 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { } b.buildURL(paramBody) - url, err := url.Parse(b.url) + urlParsed, err := url.Parse(b.url) if err != nil { return nil, err } - b.req.URL = url + b.req.URL = urlParsed trans := b.setting.Transport @@ -444,7 +461,7 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { TLSClientConfig: b.setting.TLSClientConfig, Proxy: b.setting.Proxy, Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: -1, + MaxIdleConnsPerHost: 100, } } else { // if b.transport is *http.Transport then set the settings. @@ -579,6 +596,16 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { return xml.Unmarshal(data, v) } +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return yaml.Unmarshal(data, v) +} + // Response executes request client gets response mannually. func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() diff --git a/logs/README.md b/logs/README.md index 57d7abc3..c05bcc04 100644 --- a/logs/README.md +++ b/logs/README.md @@ -16,48 +16,57 @@ As of now this logs support console, file,smtp and conn. First you must import it - import ( - "github.com/astaxie/beego/logs" - ) +```golang +import ( + "github.com/astaxie/beego/logs" +) +``` Then init a Log (example with console adapter) - log := NewLogger(10000) - log.SetLogger("console", "") +```golang +log := logs.NewLogger(10000) +log.SetLogger("console", "") +``` > the first params stand for how many channel -Use it like this: - - log.Trace("trace") - log.Info("info") - log.Warn("warning") - log.Debug("debug") - log.Critical("critical") +Use it like this: +```golang +log.Trace("trace") +log.Info("info") +log.Warn("warning") +log.Debug("debug") +log.Critical("critical") +``` ## File adapter Configure file adapter like this: - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test.log"}`) - +```golang +log := NewLogger(10000) +log.SetLogger("file", `{"filename":"test.log"}`) +``` ## Conn adapter Configure like this: - log := NewLogger(1000) - log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) - log.Info("info") - +```golang +log := NewLogger(1000) +log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) +log.Info("info") +``` ## Smtp adapter Configure like this: - log := NewLogger(10000) - log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) - log.Critical("sendmail critical") - time.Sleep(time.Second * 30) +```golang +log := NewLogger(10000) +log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) +log.Critical("sendmail critical") +time.Sleep(time.Second * 30) +``` diff --git a/logs/accesslog.go b/logs/accesslog.go index cf799dc1..3ff9e20f 100644 --- a/logs/accesslog.go +++ b/logs/accesslog.go @@ -16,13 +16,14 @@ package logs import ( "bytes" + "strings" "encoding/json" - "time" "fmt" + "time" ) const ( - apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s\n" + apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" apacheFormat = "APACHE_FORMAT" jsonFormat = "JSON_FORMAT" ) @@ -53,10 +54,9 @@ func (r *AccessLogRecord) json() ([]byte, error) { } func disableEscapeHTML(i interface{}) { - e, ok := i.(interface { + if e, ok := i.(interface { SetEscapeHTML(bool) - }); - if ok { + }); ok { e.SetEscapeHTML(false) } } @@ -64,9 +64,7 @@ func disableEscapeHTML(i interface{}) { // AccessLog - Format and print access log. func AccessLog(r *AccessLogRecord, format string) { var msg string - switch format { - case apacheFormat: timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, @@ -81,6 +79,5 @@ func AccessLog(r *AccessLogRecord, format string) { msg = string(jsonData) } } - - beeLogger.Debug(msg) + beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) } diff --git a/logs/file.go b/logs/file.go index 3e102f07..21c4438e 100644 --- a/logs/file.go +++ b/logs/file.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "os" + "path" "path/filepath" "strconv" "strings" @@ -40,6 +41,9 @@ type fileLogWriter struct { MaxLines int `json:"maxlines"` maxLinesCurLines int + MaxFiles int `json:"maxfiles"` + MaxFilesCurFiles int + // Rotate at size MaxSize int `json:"maxsize"` maxSizeCurSize int @@ -78,6 +82,9 @@ func newFileWriter() Logger { RotatePerm: "0440", Level: LevelTrace, Perm: "0660", + MaxLines: 10000000, + MaxFiles: 999, + MaxSize: 1 << 28, } return w } @@ -184,6 +191,10 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) { if err != nil { return nil, err } + + filepath := path.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(perm)) + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) if err == nil { // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask @@ -280,7 +291,7 @@ func (w *fileLogWriter) lines() (int, error) { func (w *fileLogWriter) doRotate(logTime time.Time) error { // file exists // Find the next available number - num := 1 + num := w.MaxFilesCurFiles + 1 fName := "" format := "" var openTime time.Time @@ -295,6 +306,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { goto RESTART_LOGGER } +<<<<<<< HEAD if w.Hourly { format = "2006010215" openTime = w.hourlyOpenTime @@ -315,6 +327,18 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) _, err = os.Lstat(fName) } +======= + // only when one of them be setted, then the file would be splited + if w.MaxLines > 0 || w.MaxSize > 0 { + for ; err == nil && num <= w.MaxFiles; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix) + _, err = os.Lstat(fName) + } + } else { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix) + _, err = os.Lstat(fName) + w.MaxFilesCurFiles = num +>>>>>>> old/develop } // return error if the last file checked still existed if err == nil { @@ -359,6 +383,7 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } +<<<<<<< HEAD if w.Hourly { if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && @@ -374,6 +399,15 @@ func (w *fileLogWriter) deleteOldLog() { } } } +======= + + if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } +>>>>>>> old/develop return }) } diff --git a/logs/file_test.go b/logs/file_test.go index 333209c6..ad377f6d 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -135,7 +135,7 @@ func TestFileDailyRotate_01(t *testing.T) { func TestFileDailyRotate_02(t *testing.T) { fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" testFileRotate(t, fn1, fn2, true, false) } @@ -150,7 +150,7 @@ func TestFileDailyRotate_03(t *testing.T) { func TestFileDailyRotate_04(t *testing.T) { fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" testFileDailyRotate(t, fn1, fn2) } @@ -286,6 +286,7 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) if err != nil { + t.Log(err) t.FailNow() } os.Remove(file) diff --git a/logs/log.go b/logs/log.go index 0e97a70e..a3614165 100644 --- a/logs/log.go +++ b/logs/log.go @@ -47,7 +47,7 @@ import ( // RFC5424 log message levels. const ( - LevelEmergency = iota + LevelEmergency = iota LevelAlert LevelCritical LevelError @@ -116,6 +116,7 @@ type BeeLogger struct { enableFuncCallDepth bool loggerFuncCallDepth int asynchronous bool + prefix string msgChanLen int64 msgChan chan *logMsg signalChan chan string @@ -247,7 +248,7 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { } // writeMsg will always add a '\n' character if p[len(p)-1] == '\n' { - p = p[0 : len(p)-1] + p = p[0: len(p)-1] } // set levelLoggerImpl to ensure all log message will be write out err = bl.writeMsg(levelLoggerImpl, string(p)) @@ -267,6 +268,9 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error if len(v) > 0 { msg = fmt.Sprintf(msg, v...) } + + msg = bl.prefix + " " + msg + when := time.Now() if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) @@ -305,6 +309,11 @@ func (bl *BeeLogger) SetLevel(l int) { bl.level = l } +// GetLevel Get Current log message level. +func (bl *BeeLogger) GetLevel() int { + return bl.level +} + // SetLogFuncCallDepth set log funcCallDepth func (bl *BeeLogger) SetLogFuncCallDepth(d int) { bl.loggerFuncCallDepth = d @@ -320,6 +329,11 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) { bl.enableFuncCallDepth = b } +// set prefix +func (bl *BeeLogger) SetPrefix(s string) { + bl.prefix = s +} + // start logger chan reading. // when chan is not empty, write logs. func (bl *BeeLogger) startLogger() { @@ -544,6 +558,11 @@ func SetLevel(l int) { beeLogger.SetLevel(l) } +// SetPrefix sets the prefix +func SetPrefix(s string) { + beeLogger.SetPrefix(s) +} + // EnableFuncCallDepth enable log funcCallDepth func EnableFuncCallDepth(b bool) { beeLogger.enableFuncCallDepth = b diff --git a/logs/logger.go b/logs/logger.go index fa6d9334..428d3aa0 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -93,7 +93,7 @@ const ( func formatTimeHeader(when time.Time) ([]byte, int, int) { y, mo, d := when.Date() h, mi, s := when.Clock() - ns := when.Nanosecond()/1000000 + ns := when.Nanosecond() / 1000000 //len("2006/01/02 15:04:05.123 ")==24 var buf [24]byte diff --git a/logs/logger_test.go b/logs/logger_test.go index 5a0a423d..870b0d81 100644 --- a/logs/logger_test.go +++ b/logs/logger_test.go @@ -30,8 +30,13 @@ func TestFormatHeader_0(t *testing.T) { if tm.Year() >= 2100 { break } +<<<<<<< HEAD h, _, _ := formatTimeHeader(tm) if tm.Format("2006/01/02 15:04:05.999 ") != string(h) { +======= + h, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { +>>>>>>> old/develop t.Log(tm) t.FailNow() } @@ -48,8 +53,13 @@ func TestFormatHeader_1(t *testing.T) { if tm.Year() >= year+1 { break } +<<<<<<< HEAD h, _, _ := formatTimeHeader(tm) if tm.Format("2006/01/02 15:04:05.999 ") != string(h) { +======= + h, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { +>>>>>>> old/develop t.Log(tm) t.FailNow() } diff --git a/logs/multifile.go b/logs/multifile.go index 63204e17..90168274 100644 --- a/logs/multifile.go +++ b/logs/multifile.go @@ -67,7 +67,10 @@ func (f *multiFileLogWriter) Init(config string) error { jsonMap["level"] = i bs, _ := json.Marshal(jsonMap) writer = newFileWriter().(*fileLogWriter) - writer.Init(string(bs)) + err := writer.Init(string(bs)) + if err != nil { + return err + } f.writers[i] = writer } } diff --git a/migration/ddl.go b/migration/ddl.go index cea10355..9313acf8 100644 --- a/migration/ddl.go +++ b/migration/ddl.go @@ -322,7 +322,7 @@ func (m *Migration) GetSQL() (sql string) { sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) } - if len(m.Columns) > index { + if len(m.Columns) > index+1 { sql += "," } } @@ -355,7 +355,7 @@ func (m *Migration) GetSQL() (sql string) { } else { sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) } - if len(m.Columns) > index { + if len(m.Columns) > index+1 { sql += "," } } @@ -366,14 +366,14 @@ func (m *Migration) GetSQL() (sql string) { for index, unique := range m.Uniques { sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) - if len(m.Uniques) > index { + if len(m.Uniques) > index+1 { sql += "," } } for index, column := range m.Renames { sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) - if len(m.Renames) > index { + if len(m.Renames) > index+1 { sql += "," } } diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index ba7afb53..7c9aa51e 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -197,6 +197,10 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } + + if fi.description != "" { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + } columns = append(columns, column) } diff --git a/orm/db.go b/orm/db.go index 5862d003..87d08df6 100644 --- a/orm/db.go +++ b/orm/db.go @@ -536,6 +536,8 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a updates := make([]string, len(names)) var conflitValue interface{} for i, v := range names { + // identifier in database may not be case-sensitive, so quote it + v = fmt.Sprintf("%s%s%s", Q, v, Q) marks[i] = "?" valueStr := argsMap[strings.ToLower(v)] if v == args0 { @@ -969,6 +971,10 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + if qs.forupdate { + query += " FOR UPDATE" + } + d.ins.ReplaceMarks(&query) var rs *sql.Rows diff --git a/orm/models_fields.go b/orm/models_fields.go index d23c49fa..b4fad94f 100644 --- a/orm/models_fields.go +++ b/orm/models_fields.go @@ -86,7 +86,7 @@ func (e *BooleanField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Bool() - if err != nil { + if err == nil { e.Set(v) } return err @@ -191,7 +191,7 @@ func (e *TimeField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := timeParse(d, formatTime) - if err != nil { + if err == nil { e.Set(v) } return err @@ -250,7 +250,7 @@ func (e *DateField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := timeParse(d, formatDate) - if err != nil { + if err == nil { e.Set(v) } return err @@ -300,7 +300,7 @@ func (e *DateTimeField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := timeParse(d, formatDateTime) - if err != nil { + if err == nil { e.Set(v) } return err @@ -350,9 +350,10 @@ func (e *FloatField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Float64() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -397,9 +398,10 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Int16() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -444,9 +446,10 @@ func (e *IntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Int32() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -491,9 +494,10 @@ func (e *BigIntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Int64() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -538,9 +542,10 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Uint16() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -585,9 +590,10 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Uint32() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } @@ -632,9 +638,10 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { e.Set(d) case string: v, err := StrTo(d).Uint64() - if err != nil { + if err == nil { e.Set(v) } + return err default: return fmt.Errorf(" unknown value `%s`", value) } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 646e2273..479f5ae6 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -136,6 +136,7 @@ type fieldInfo struct { decimals int isFielder bool // implement Fielder interface onDelete string + description string } // new field info @@ -300,6 +301,7 @@ checkType: fi.sf = sf fi.fullName = mi.fullName + mName + "." + sf.Name + fi.description = sf.Tag.Get("description") fi.null = attrs["null"] fi.index = attrs["index"] fi.auto = attrs["auto"] diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 4a3a37f9..a4d733b6 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -75,7 +75,8 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) break } //record current field index - fi.fieldIndex = append(index, i) + fi.fieldIndex = append(fi.fieldIndex, index...) + fi.fieldIndex = append(fi.fieldIndex, i) fi.mi = mi fi.inModel = true if !mi.fields.Add(fi) { diff --git a/orm/models_test.go b/orm/models_test.go index d6c2b581..e3a635f2 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -433,53 +433,57 @@ var ( dDbBaser dbBaser ) +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + func init() { Debug, _ = StrTo(DBARGS.Debug).Bool() if DBARGS.Driver == "" || DBARGS.Source == "" { - fmt.Println(`need driver and source! - -Default DB Drivers. - - driver: url - mysql: https://github.com/go-sql-driver/mysql - sqlite3: https://github.com/mattn/go-sqlite3 -postgres: https://github.com/lib/pq -tidb: https://github.com/pingcap/tidb - -usage: - -go get -u github.com/astaxie/beego/orm -go get -u github.com/go-sql-driver/mysql -go get -u github.com/mattn/go-sqlite3 -go get -u github.com/lib/pq -go get -u github.com/pingcap/tidb - -#### MySQL -mysql -u root -e 'create database orm_test;' -export ORM_DRIVER=mysql -export ORM_SOURCE="root:@/orm_test?charset=utf8" -go test -v github.com/astaxie/beego/orm - - -#### Sqlite3 -export ORM_DRIVER=sqlite3 -export ORM_SOURCE='file:memory_test?mode=memory' -go test -v github.com/astaxie/beego/orm - - -#### PostgreSQL -psql -c 'create database orm_test;' -U postgres -export ORM_DRIVER=postgres -export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" -go test -v github.com/astaxie/beego/orm - -#### TiDB -export ORM_DRIVER=tidb -export ORM_SOURCE='memory://test/test' -go test -v github.com/astaxie/beego/orm - -`) + fmt.Println(helpinfo) os.Exit(2) } diff --git a/orm/orm.go b/orm/orm.go index fcf82590..e59589aa 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build go1.8 + // Package orm provide ORM for MySQL/PostgreSQL/sqlite // Simple Usage // @@ -52,6 +54,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -458,11 +461,15 @@ func (o *orm) Using(name string) error { // begin transaction func (o *orm) Begin() error { + return o.BeginTx(context.Background(), nil) +} + +func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { if o.isTx { return ErrTxHasBegan } var tx *sql.Tx - tx, err := o.db.(txer).Begin() + tx, err := o.db.(txer).BeginTx(ctx, opts) if err != nil { return err } diff --git a/orm/orm_log.go b/orm/orm_log.go index 26c73f9e..979dbbc6 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "io" @@ -150,6 +151,13 @@ func (d *dbQueryLog) Begin() (*sql.Tx, error) { return tx, err } +func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + a := time.Now() + tx, err := d.db.(txer).BeginTx(ctx, opts) + debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) + return tx, err +} + func (d *dbQueryLog) Commit() error { a := time.Now() err := d.db.(txEnder).Commit() diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 4e33646d..8b5d6fd2 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -55,16 +55,17 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - orm *orm + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []string + distinct bool + forupdate bool + orm *orm } var _ QuerySeter = new(querySet) @@ -127,6 +128,12 @@ func (o querySet) Distinct() QuerySeter { return &o } +// add FOR UPDATE to SELECT +func (o querySet) ForUpdate() QuerySeter { + o.forupdate = true + return &o +} + // set relation model to query together. // it will query relation models and assign to parent model. func (o querySet) RelatedSel(params ...interface{}) QuerySeter { @@ -191,7 +198,11 @@ func (o *querySet) PrepareInsert() (Inserter, error) { // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + if num == 0 { + return 0, ErrNoRows + } + return num, err } // query one row data and map to containers. diff --git a/orm/orm_test.go b/orm/orm_test.go index f1f2d85e..cdb2fe3f 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build go1.8 + package orm import ( "bytes" + "context" "database/sql" "fmt" "io/ioutil" @@ -452,9 +455,9 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime))) + throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) + throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) + throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) } func TestDataCustomTypes(t *testing.T) { @@ -1008,13 +1011,13 @@ func TestAll(t *testing.T) { qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, err) + throwFailNow(t, AssertIs(err, ErrNoRows)) throwFailNow(t, AssertIs(num, 0)) var users3 []*User qs = dORM.QueryTable("user") num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, err) + throwFailNow(t, AssertIs(err, ErrNoRows)) throwFailNow(t, AssertIs(num, 0)) throwFailNow(t, AssertIs(users3 == nil, false)) } @@ -1990,6 +1993,66 @@ func TestTransaction(t *testing.T) { } +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + func TestReadOrCreate(t *testing.T) { u := &User{ UserName: "Kyle", @@ -2260,6 +2323,7 @@ func TestIgnoreCaseTag(t *testing.T) { throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) } + func TestInsertOrUpdate(t *testing.T) { RegisterModel(new(User)) user := User{UserName: "unique_username133", Status: 1, Password: "o"} @@ -2297,6 +2361,11 @@ func TestInsertOrUpdate(t *testing.T) { throwFailNow(t, AssertIs(user2.Status, test.Status)) throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) } + + //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } //test3 + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") if err != nil { diff --git a/orm/types.go b/orm/types.go index 3e6a9e87..2fdc98c7 100644 --- a/orm/types.go +++ b/orm/types.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "reflect" "time" @@ -106,6 +107,17 @@ type Ormer interface { // ... // err = o.Rollback() Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error // commit transaction Commit() error // rollback transaction @@ -190,6 +202,10 @@ type QuerySeter interface { // Distinct(). // All(&permissions) Distinct() QuerySeter + // set FOR UPDATE to query. + // for example: + // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) + ForUpdate() QuerySeter // return QuerySeter execution result number // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() @@ -397,6 +413,7 @@ type dbQuerier interface { // transaction beginner type txer interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } // transaction ending diff --git a/parser.go b/parser.go index 1933c6c6..2ac48b85 100644 --- a/parser.go +++ b/parser.go @@ -114,20 +114,21 @@ type parsedParam struct { func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { if f.Doc != nil { - parsedComment, err := parseComment(f.Doc.List) + parsedComments, err := parseComment(f.Doc.List) if err != nil { return err } - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - genInfoList[key] = append(genInfoList[key], cc) + for _, parsedComment := range parsedComments { + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + genInfoList[key] = append(genInfoList[key], cc) + } } - } return nil } @@ -177,26 +178,13 @@ func paramInPath(name, route string) bool { var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) -func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { - pc = &parsedComment{} +func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { + pcs = []*parsedComment{} + params := map[string]parsedParam{} + for _, c := range lines { t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - //pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - //pc.hasGet = strings.Contains(methods, "get") - } - } else { - return nil, errors.New("Router information is missing") - } - } else if strings.HasPrefix(t, "@Param") { + if strings.HasPrefix(t, "@Param") { pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) if len(pv) < 4 { logs.Error("Invalid @Param format. Needs at least 4 parameters") @@ -217,10 +205,32 @@ func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { p.defValue = pv[3] p.required, _ = strconv.ParseBool(pv[4]) } - if pc.params == nil { - pc.params = map[string]parsedParam{} + params[funcParamName] = p + } + } + + for _, c := range lines { + var pc = &parsedComment{} + pc.params = params + + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + pcs = append(pcs, pc) + } else { + return nil, errors.New("Router information is missing") } - pc.params[funcParamName] = p } } return diff --git a/router.go b/router.go index 2f5d2eae..2e0cea25 100644 --- a/router.go +++ b/router.go @@ -43,7 +43,7 @@ const ( ) const ( - routerTypeBeego = iota + routerTypeBeego = iota routerTypeRESTFul routerTypeHandler ) @@ -71,7 +71,7 @@ var ( // these beego.Controller's methods shouldn't reflect to AutoRouter exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", - "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", + "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", @@ -201,9 +201,12 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt numOfFields := elemVal.NumField() for i := 0; i < numOfFields; i++ { - fieldVal := elemVal.Field(i) fieldType := elemType.Field(i) - execElem.FieldByName(fieldType.Name).Set(fieldVal) + elemField := execElem.FieldByName(fieldType.Name) + if elemField.CanSet() { + fieldVal := elemVal.Field(i) + elemField.Set(fieldVal) + } } return execController @@ -874,13 +877,15 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } Admin: - //admin module record QPS +//admin module record QPS statusCode := context.ResponseWriter.Status if statusCode == 0 { statusCode = 200 } + logAccess(context, &startTime, statusCode) + if BConfig.Listen.EnableAdmin { timeDur := time.Since(startTime) pattern := "" @@ -897,49 +902,30 @@ Admin: } } - if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { - timeDur := time.Since(startTime) + if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { var devInfo string - + timeDur := time.Since(startTime) iswin := (runtime.GOOS == "windows") statusColor := logs.ColorByStatus(iswin, statusCode) methodColor := logs.ColorByMethod(iswin, r.Method) resetColor := logs.ColorByMethod(iswin, "") - if BConfig.Log.AccessLogsFormat != "" { - record := &logs.AccessLogRecord{ - RemoteAddr: context.Input.IP(), - RequestTime: startTime, - RequestMethod: r.Method, - Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), - ServerProtocol: r.Proto, - Host: r.Host, - Status: statusCode, - ElapsedTime: timeDur, - HTTPReferrer: r.Header.Get("Referer"), - HTTPUserAgent: r.Header.Get("User-Agent"), - RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: 0, //@todo this one is missing! - } - logs.AccessLog(record, BConfig.Log.AccessLogsFormat) - } else { - if findRouter { - if routerInfo != nil { - devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode, - resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path, - routerInfo.pattern) - } else { - devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, - timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path) - } + if findRouter { + if routerInfo != nil { + devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode, + resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path, + routerInfo.pattern) } else { devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, - timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path) - } - if iswin { - logs.W32Debug(devInfo) - } else { - logs.Debug(devInfo) + timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path) } + } else { + devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, + timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path) + } + if iswin { + logs.W32Debug(devInfo) + } else { + logs.Debug(devInfo) } } // Call WriteHeader if status code has been set changed @@ -957,7 +943,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex context.RenderMethodResult(resultValue) } } - if !context.ResponseWriter.Started && context.Output.Status == 0 { + if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { context.Output.SetStatus(200) } } @@ -988,3 +974,38 @@ func toURL(params map[string]string) string { } return strings.TrimRight(u, "&") } + +func logAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + //Skip logging if AccessLogs config is false + if !BConfig.Log.AccessLogs { + return + } + //Skip logging static requests unless EnableStaticLogs config is true + if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: 0, //@todo this one is missing! + } + logs.AccessLog(record, BConfig.Log.AccessLogsFormat) +} diff --git a/router_test.go b/router_test.go index 720b4ca8..90104427 100644 --- a/router_test.go +++ b/router_test.go @@ -695,3 +695,30 @@ func beegoResetParams(ctx *context.Context) { func beegoHandleResetParams(ctx *context.Context) { ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) } + +// YAML +type YAMLController struct { + Controller +} + +func (jc *YAMLController) Prepare() { + jc.Data["yaml"] = "prepare" + jc.ServeYAML() +} + +func (jc *YAMLController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestYAMLPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/yaml/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/yaml/list", &YAMLController{}) + handler.ServeHTTP(w, r) + if strings.TrimSpace(w.Body.String()) != "prepare" { + t.Errorf(w.Body.String()) + } +} diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index 55595851..5c382d61 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -14,9 +14,9 @@ // Package redis for session provider // -// depend on github.com/garyburd/redigo/redis +// depend on github.com/gomodule/redigo/redis // -// go install github.com/garyburd/redigo/redis +// go install github.com/gomodule/redigo/redis // // Usage: // import( @@ -24,10 +24,10 @@ // "github.com/astaxie/beego/session" // ) // -// func init() { -// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) -// go globalSessions.GC() -// } +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } // // more docs: http://beego.me/docs/module/session.md package redis @@ -37,10 +37,11 @@ import ( "strconv" "strings" "sync" + "time" "github.com/astaxie/beego/session" - "github.com/garyburd/redigo/redis" + "github.com/gomodule/redigo/redis" ) var redispder = &Provider{} @@ -118,8 +119,8 @@ type Provider struct { } // SessionInit init redis session -// savepath like redis server addr,pool size,password,dbnum -// e.g. 127.0.0.1:6379,100,astaxie,0 +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") @@ -149,27 +150,39 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - rp.poollist = redis.NewPool(func() (redis.Conn, error) { - c, err := redis.Dial("tcp", rp.savePath) - if err != nil { - return nil, err + var idleTimeout time.Duration = 0 + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second } - if rp.password != "" { - if _, err = c.Do("AUTH", rp.password); err != nil { - c.Close() - return nil, err - } - } - //some redis proxy such as twemproxy is not support select command - if rp.dbNum > 0 { - _, err = c.Do("SELECT", rp.dbNum) + } + rp.poollist = &redis.Pool{ + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", rp.savePath) if err != nil { - c.Close() return nil, err } - } - return c, err - }, rp.poolsize) + if rp.password != "" { + if _, err = c.Do("AUTH", rp.password); err != nil { + c.Close() + return nil, err + } + } + // some redis proxy such as twemproxy is not support select command + if rp.dbNum > 0 { + _, err = c.Do("SELECT", rp.dbNum) + if err != nil { + c.Close() + return nil, err + } + } + return c, err + }, + MaxIdle: rp.poolsize, + } + + rp.poollist.IdleTimeout = idleTimeout return rp.poollist.Get().Err() } diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..2fe300df --- /dev/null +++ b/session/redis_cluster/redis_cluster.go @@ -0,0 +1,220 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster +import ( + "net/http" + "strconv" + "strings" + "sync" + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" + "time" +) + +var redispder = &Provider{} + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = 1000 + +// SessionStore redis_cluster session store +type SessionStore struct { + p *rediss.ClusterClient + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) +} + +// Provider redis_cluster session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *rediss.ClusterClient +} + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + }) + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != rediss.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_cluster", redispder) +} diff --git a/staticfile.go b/staticfile.go index bbb2a1fb..30ee3dfb 100644 --- a/staticfile.go +++ b/staticfile.go @@ -74,7 +74,7 @@ func serverStaticRouter(ctx *context.Context) { if enableCompress { acceptEncoding = context.ParseEncoding(ctx.Request) } - b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) + b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) if err != nil { if BConfig.RunMode == DEV { logs.Warn("Can't compress the file:", filePath, err) @@ -89,47 +89,53 @@ func serverStaticRouter(ctx *context.Context) { ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) } - http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch) + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) } type serveContentHolder struct { - *bytes.Reader + data []byte modTime time.Time size int64 encoding string } +type serveContentReader struct { + *bytes.Reader +} + var ( staticFileMap = make(map[string]*serveContentHolder) mapLock sync.RWMutex ) -func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { mapKey := acceptEncoding + ":" + filePath mapLock.RLock() mapFile := staticFileMap[mapKey] mapLock.RUnlock() if isOk(mapFile, fi) { - return mapFile.encoding != "", mapFile.encoding, mapFile, nil + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil } mapLock.Lock() defer mapLock.Unlock() if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { file, err := os.Open(filePath) if err != nil { - return false, "", nil, err + return false, "", nil, nil, err } defer file.Close() var bufferWriter bytes.Buffer _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) if err != nil { - return false, "", nil, err + return false, "", nil, nil, err } - mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} + mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} staticFileMap[mapKey] = mapFile } - return mapFile.encoding != "", mapFile.encoding, mapFile, nil + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil } func isOk(s *serveContentHolder, fi os.FileInfo) bool { diff --git a/staticfile_test.go b/staticfile_test.go index a043b4fd..69667bf8 100644 --- a/staticfile_test.go +++ b/staticfile_test.go @@ -16,7 +16,7 @@ var licenseFile = filepath.Join(currentWorkDir, "LICENSE") func testOpenFile(encoding string, content []byte, t *testing.T) { fi, _ := os.Stat(licenseFile) - b, n, sch, err := openFile(licenseFile, fi, encoding) + b, n, sch, reader, err := openFile(licenseFile, fi, encoding) if err != nil { t.Log(err) t.Fail() @@ -24,7 +24,7 @@ func testOpenFile(encoding string, content []byte, t *testing.T) { t.Log("open static file encoding "+n, b) - assetOpenFileAndContent(sch, content, t) + assetOpenFileAndContent(sch, reader, content, t) } func TestOpenStaticFile_1(t *testing.T) { file, _ := os.Open(licenseFile) @@ -53,13 +53,13 @@ func TestOpenStaticFileDeflate_1(t *testing.T) { testOpenFile("deflate", content, t) } -func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) { +func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { t.Log(sch.size, len(content)) if sch.size != int64(len(content)) { t.Log("static content file size not same") t.Fail() } - bs, _ := ioutil.ReadAll(sch) + bs, _ := ioutil.ReadAll(reader) for i, v := range content { if v != bs[i] { t.Log("content not same") diff --git a/swagger/swagger.go b/swagger/swagger.go index 105d7fa4..a55676cd 100644 --- a/swagger/swagger.go +++ b/swagger/swagger.go @@ -122,6 +122,7 @@ type Schema struct { Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` } // Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification @@ -131,7 +132,7 @@ type Propertie struct { Description string `json:"description,omitempty" yaml:"description,omitempty"` Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"` - Example string `json:"example,omitempty" yaml:"example,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` Required []string `json:"required,omitempty" yaml:"required,omitempty"` Format string `json:"format,omitempty" yaml:"format,omitempty"` ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` diff --git a/template.go b/template.go index 7adc0403..88c84e2f 100644 --- a/template.go +++ b/template.go @@ -218,6 +218,7 @@ func BuildTemplate(dir string, files ...string) error { } if err != nil { logs.Error("parse template err:", file, err) + templatesLock.Unlock() return err } beeTemplates[file] = t diff --git a/templatefunc.go b/templatefunc.go index a104fd24..e4d4667a 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -17,6 +17,7 @@ package beego import ( "errors" "fmt" + "html" "html/template" "net/url" "reflect" @@ -84,24 +85,24 @@ func DateFormat(t time.Time, layout string) (datestring string) { var datePatterns = []string{ // year "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 + "y", "06", //A two digit representation of a year Examples: 99 or 03 // month - "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 - "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 - "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec "F", "January", // A full textual representation of a month, such as January or March January through December // day "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 - "j", "2", // Day of the month without leading zeros 1 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 // week - "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "D", "Mon", // A textual representation of a day, three letters Mon through Sun "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday // time - "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 @@ -207,14 +208,12 @@ func Htmlquote(text string) string { '<'&">' */ - text = strings.Replace(text, "&", "&", -1) // Must be done first! - text = strings.Replace(text, "<", "<", -1) - text = strings.Replace(text, ">", ">", -1) - text = strings.Replace(text, "'", "'", -1) - text = strings.Replace(text, "\"", """, -1) - text = strings.Replace(text, "“", "“", -1) - text = strings.Replace(text, "”", "”", -1) - text = strings.Replace(text, " ", " ", -1) + text = html.EscapeString(text) + text = strings.NewReplacer( + `“`, "“", + `”`, "”", + ` `, " ", + ).Replace(text) return strings.TrimSpace(text) } @@ -228,17 +227,7 @@ func Htmlunquote(text string) string { '<\\'&">' */ - // strings.Replace(s, old, new, n) - // 在s字符串中,把old字符串替换为new字符串,n表示替换的次数,小于0表示全部替换 - - text = strings.Replace(text, " ", " ", -1) - text = strings.Replace(text, "”", "”", -1) - text = strings.Replace(text, "“", "“", -1) - text = strings.Replace(text, """, "\"", -1) - text = strings.Replace(text, "'", "'", -1) - text = strings.Replace(text, ">", ">", -1) - text = strings.Replace(text, "<", "<", -1) - text = strings.Replace(text, "&", "&", -1) // Must be done last! + text = html.UnescapeString(text) return strings.TrimSpace(text) } diff --git a/templatefunc_test.go b/templatefunc_test.go index 9df61125..aa82c26f 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -94,7 +94,7 @@ func TestCompareRelated(t *testing.T) { } func TestHtmlquote(t *testing.T) { - h := `<' ”“&">` + h := `<' ”“&">` s := `<' ”“&">` if Htmlquote(s) != h { t.Error("should be equal") @@ -102,8 +102,8 @@ func TestHtmlquote(t *testing.T) { } func TestHtmlunquote(t *testing.T) { - h := `<' ”“&">` - s := `<' ”“&">` + h := `<' ”“&">` + s := `<' ”“&">` if Htmlunquote(h) != s { t.Error("should be equal") } diff --git a/validation/validation.go b/validation/validation.go index 0d89be30..ca1e211f 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -245,7 +245,21 @@ func (v *Validation) ZipCode(obj interface{}, key string) *Result { } func (v *Validation) apply(chk Validator, obj interface{}) *Result { - if chk.IsSatisfied(obj) { + if nil == obj { + if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { + if reflect.ValueOf(obj).IsNil() { + if chk.IsSatisfied(nil) { + return &Result{Ok: true} + } + } else { + if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { + return &Result{Ok: true} + } + } + } else if chk.IsSatisfied(obj) { return &Result{Ok: true} } @@ -351,13 +365,24 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) { return } - var hasReuired bool + var hasRequired bool for _, vf := range vfs { if vf.Name == "Required" { - hasReuired = true + hasRequired = true } - if !hasReuired && v.RequiredFirst && len(objV.Field(i).String()) == 0 { + currentField := objV.Field(i).Interface() + if objV.Field(i).Kind() == reflect.Ptr { + if objV.Field(i).IsNil() { + currentField = "" + } else { + currentField = objV.Field(i).Elem().Interface() + } + } + + + chk := Required{""}.IsSatisfied(currentField) + if !hasRequired && v.RequiredFirst && !chk { if _, ok := CanSkipFuncs[vf.Name]; ok { continue } @@ -414,3 +439,9 @@ func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { } return pass, err } + +func (v *Validation) CanSkipAlso(skipFunc string) { + if _, ok := CanSkipFuncs[skipFunc]; !ok { + CanSkipFuncs[skipFunc] = struct{}{} + } +} diff --git a/validation/validation_test.go b/validation/validation_test.go index bf612015..f97105fd 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -442,3 +442,122 @@ func TestSkipValid(t *testing.T) { t.Fatal("validation should be passed") } } + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} +