diff --git a/.github/ISSUE_TEMPLATE b/.github/ISSUE_TEMPLATE new file mode 100644 index 00000000..db349198 --- /dev/null +++ b/.github/ISSUE_TEMPLATE @@ -0,0 +1,17 @@ +Please answer these questions before submitting your issue. Thanks! + +1. What version of Go and beego are you using (`bee version`)? + + +2. What operating system and processor architecture are you using (`go env`)? + + +3. What did you do? +If possible, provide a recipe for reproducing the error. +A complete runnable program is good. + + +4. What did you expect to see? + + +5. What did you see instead? \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 3c821dcd..92d9ac8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - tip - - 1.6.0 + - 1.6 - 1.5.3 - 1.4.3 services: @@ -32,12 +31,12 @@ install: - go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/ledis - go get golang.org/x/tools/cmd/vet - - go get github.com/golang/lint/golint - go get github.com/ssdb/gossdb/ssdb before_script: - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" + - sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi" - mkdir -p res/var - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d after_script: @@ -45,7 +44,4 @@ after_script: - rm -rf ./res/var/* script: - go vet -x ./... - - $HOME/gopath/bin/golint ./... - go test -v ./... -notifications: - webhooks: https://hooks.pubu.im/services/z7m9bvybl3rgtg9 diff --git a/admin.go b/admin.go index 031e6421..94fa55b9 100644 --- a/admin.go +++ b/admin.go @@ -24,6 +24,7 @@ import ( "time" "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" ) @@ -196,7 +197,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) { BeforeExec: "Before Exec", AfterExec: "After Exec", FinishRouter: "Finish Router"} { - if bf, ok := BeeApp.Handlers.filters[k]; ok { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { filterType = fr filterTypes = append(filterTypes, filterType) resultList := new([][]string) @@ -410,7 +411,7 @@ func (admin *adminApp) Run() { for p, f := range admin.routers { http.Handle(p, f) } - BeeLogger.Info("Admin server Running on %s", addr) + logs.Info("Admin server Running on %s", addr) var err error if BConfig.Listen.Graceful { @@ -419,6 +420,6 @@ func (admin *adminApp) Run() { err = http.ListenAndServe(addr, nil) } if err != nil { - BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) } } diff --git a/app.go b/app.go index af54ea4b..77116a81 100644 --- a/app.go +++ b/app.go @@ -24,6 +24,7 @@ import ( "time" "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -68,9 +69,9 @@ func (app *App) Run() { if BConfig.Listen.EnableFcgi { if BConfig.Listen.EnableStdIo { if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O - BeeLogger.Info("Use FCGI via standard I/O") + logs.Info("Use FCGI via standard I/O") } else { - BeeLogger.Critical("Cannot use FCGI via standard I/O", err) + logs.Critical("Cannot use FCGI via standard I/O", err) } return } @@ -84,10 +85,10 @@ func (app *App) Run() { l, err = net.Listen("tcp", addr) } if err != nil { - BeeLogger.Critical("Listen: ", err) + logs.Critical("Listen: ", err) } if err = fcgi.Serve(l, app.Handlers); err != nil { - BeeLogger.Critical("fcgi.Serve: ", err) + logs.Critical("fcgi.Serve: ", err) } return } @@ -95,6 +96,7 @@ func (app *App) Run() { app.Server.Handler = app.Handlers app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") // run graceful mode if BConfig.Listen.Graceful { @@ -111,7 +113,7 @@ func (app *App) Run() { server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -126,7 +128,7 @@ func (app *App) Run() { server.Network = "tcp4" } if err := server.ListenAndServe(); err != nil { - BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -144,9 +146,9 @@ func (app *App) Run() { if BConfig.Listen.HTTPSPort != 0 { app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) } - BeeLogger.Info("https server Running on %s", app.Server.Addr) + logs.Info("https server Running on %s", app.Server.Addr) if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err) + logs.Critical("ListenAndServeTLS: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true } @@ -155,24 +157,24 @@ func (app *App) Run() { if BConfig.Listen.EnableHTTP { go func() { app.Server.Addr = addr - BeeLogger.Info("http server Running on %s", app.Server.Addr) + logs.Info("http server Running on %s", app.Server.Addr) if BConfig.Listen.ListenTCP4 { ln, err := net.Listen("tcp4", app.Server.Addr) if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return } if err = app.Server.Serve(ln); err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return } } else { if err := app.Server.ListenAndServe(); err != nil { - BeeLogger.Critical("ListenAndServe: ", err) + logs.Critical("ListenAndServe: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true } diff --git a/beego.go b/beego.go index fb628e5f..16601ecc 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.6.0" + VERSION = "1.6.1" // DEV is for develop DEV = "dev" @@ -51,6 +51,7 @@ func AddAPPStartHook(hf hookfunc) { // beego.Run(":8089") // beego.Run("127.0.0.1:8089") func Run(params ...string) { + initBeforeHTTPRun() if len(params) > 0 && params[0] != "" { @@ -74,6 +75,7 @@ func initBeforeHTTPRun() { AddAPPStartHook(registerDocs) AddAPPStartHook(registerTemplate) AddAPPStartHook(registerAdmin) + AddAPPStartHook(registerGzip) for _, hk := range hooks { if err := hk(); err != nil { @@ -87,5 +89,6 @@ func TestBeegoInit(ap string) { os.Setenv("BEEGO_RUNMODE", "test") appConfigPath = filepath.Join(ap, "conf", "app.conf") os.Chdir(ap) + LoadAppConfig(appConfigProvider, appConfigPath) initBeforeHTTPRun() } diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 47c5acc6..6b81da4d 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -18,9 +18,8 @@ import ( "testing" "time" - "github.com/garyburd/redigo/redis" - "github.com/astaxie/beego/cache" + "github.com/garyburd/redigo/redis" ) func TestRedisCache(t *testing.T) { diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go index e03ba343..dd474960 100644 --- a/cache/ssdb/ssdb_test.go +++ b/cache/ssdb/ssdb_test.go @@ -1,10 +1,11 @@ package ssdb import ( - "github.com/astaxie/beego/cache" "strconv" "testing" "time" + + "github.com/astaxie/beego/cache" ) func TestSsdbcacheCache(t *testing.T) { diff --git a/config.go b/config.go index 2761e7cb..72045fd3 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/astaxie/beego/config" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/session" "github.com/astaxie/beego/utils" ) @@ -293,14 +294,14 @@ func parseConfig(appConfigPath string) (err error) { } //init log - BeeLogger.Reset() + logs.Reset() for adaptor, config := range BConfig.Log.Outputs { - err = BeeLogger.SetLogger(adaptor, config) + err = logs.SetLogger(adaptor, config) if err != nil { fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err) } } - SetLogFuncCall(BConfig.Log.FileLineNum) + logs.SetLogFuncCall(BConfig.Log.FileLineNum) return nil } @@ -316,10 +317,6 @@ func LoadAppConfig(adapterName, configPath string) error { return fmt.Errorf("the target config file: %s don't exist", configPath) } - if absConfigPath == appConfigPath { - return nil - } - appConfigPath = absConfigPath appConfigProvider = adapterName @@ -353,7 +350,7 @@ func (b *beegoAppConfig) String(key string) string { } func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { return v } return b.innerConfig.Strings(key) diff --git a/config/ini.go b/config/ini.go index 829901d8..1ec56238 100644 --- a/config/ini.go +++ b/config/ini.go @@ -82,6 +82,10 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { if err == io.EOF { break } + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } if bytes.Equal(line, bEmpty) { continue } diff --git a/context/acceptencoder.go b/context/acceptencoder.go index 033d9ca8..cb735445 100644 --- a/context/acceptencoder.go +++ b/context/acceptencoder.go @@ -27,6 +27,33 @@ import ( "sync" ) +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + type resetWriter interface { io.Writer Reset(w io.Writer) @@ -41,20 +68,20 @@ func (n nopResetWriter) Reset(w io.Writer) { } type acceptEncoder struct { - name string - levelEncode func(int) resetWriter - bestSpeedPool *sync.Pool - bestCompressionPool *sync.Pool + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool } func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { - if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { return nopResetWriter{wr} } var rwr resetWriter switch level { case flate.BestSpeed: - rwr = ac.bestSpeedPool.Get().(resetWriter) + rwr = ac.customCompressLevelPool.Get().(resetWriter) case flate.BestCompression: rwr = ac.bestCompressionPool.Get().(resetWriter) default: @@ -65,13 +92,18 @@ func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { } func (ac acceptEncoder) put(wr resetWriter, level int) { - if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { return } wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + switch level { - case flate.BestSpeed: - ac.bestSpeedPool.Put(wr) + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) case flate.BestCompression: ac.bestCompressionPool.Put(wr) } @@ -79,28 +111,22 @@ func (ac acceptEncoder) put(wr resetWriter, level int) { var ( noneCompressEncoder = acceptEncoder{"", nil, nil, nil} - gzipCompressEncoder = acceptEncoder{"gzip", - func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, - &sync.Pool{ - New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr }, - }, - &sync.Pool{ - New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }, - }, + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, } //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed //deflate //The "zlib" format defined in RFC 1950 [31] in combination with //the "deflate" compression mechanism described in RFC 1951 [29]. - deflateCompressEncoder = acceptEncoder{"deflate", - func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, - &sync.Pool{ - New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr }, - }, - &sync.Pool{ - New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }, - }, + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, } ) @@ -120,7 +146,11 @@ func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, // WriteBody reads writes content to writer by the specific encoding(gzip/deflate) func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { - return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) } // writeLevel reads from reader,writes to writer by specific encoding and compress level @@ -156,7 +186,10 @@ func ParseEncoding(r *http.Request) string { if r == nil { return "" } - return parseEncoding(r) + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" } type q struct { diff --git a/context/context.go b/context/context.go index 31698c85..fee5e1c5 100644 --- a/context/context.go +++ b/context/context.go @@ -24,13 +24,11 @@ package context import ( "bufio" - "bytes" "crypto/hmac" "crypto/sha1" "encoding/base64" "errors" "fmt" - "io" "net" "net/http" "strconv" @@ -79,6 +77,7 @@ func (ctx *Context) Redirect(status int, localurl string) { // Abort stops this request. // if beego.ErrorMaps exists, panic body. func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) panic(body) } @@ -190,34 +189,26 @@ func (r *Response) reset(rw http.ResponseWriter) { // Write writes the data to the connection as part of an HTTP reply, // and sets `started` to true. // started means the response has sent out. -func (w *Response) Write(p []byte) (int, error) { - w.Started = true - return w.ResponseWriter.Write(p) -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (w *Response) Copy(buf *bytes.Buffer) (int64, error) { - w.Started = true - return io.Copy(w.ResponseWriter, buf) +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) } // WriteHeader sends an HTTP response header with status code, // and sets `started` to true. -func (w *Response) WriteHeader(code int) { - if w.Status > 0 { +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { //prevent multiple response.WriteHeader calls return } - w.Status = code - w.Started = true - w.ResponseWriter.WriteHeader(code) + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) } // Hijack hijacker for http -func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := w.ResponseWriter.(http.Hijacker) +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, errors.New("webserver doesn't support hijacking") } @@ -225,15 +216,15 @@ func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { } // Flush http.Flusher -func (w *Response) Flush() { - if f, ok := w.ResponseWriter.(http.Flusher); ok { +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() } } // CloseNotify http.CloseNotifier -func (w *Response) CloseNotify() <-chan bool { - if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } return nil diff --git a/context/input_test.go b/context/input_test.go index 24f6fd99..8887aec4 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -100,7 +100,7 @@ func TestSubDomain(t *testing.T) { /* TODO Fix this r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "" { t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) } diff --git a/context/output.go b/context/output.go index 17404702..3568d89c 100644 --- a/context/output.go +++ b/context/output.go @@ -21,8 +21,10 @@ import ( "errors" "fmt" "html/template" + "io" "mime" "net/http" + "os" "path/filepath" "strconv" "strings" @@ -72,10 +74,11 @@ func (output *BeegoOutput) Body(content []byte) error { if output.Status != 0 { output.Context.ResponseWriter.WriteHeader(output.Status) output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true } - - _, err := output.Context.ResponseWriter.Copy(buf) - return err + io.Copy(output.Context.ResponseWriter, buf) + return nil } // Cookie sets cookie value via given key. @@ -235,6 +238,12 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { // Download forces response for download file. // it prepares the download response header automatically. func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + output.Header("Content-Description", "File Transfer") output.Header("Content-Type", "application/octet-stream") if len(filename) > 0 && filename[0] != "" { diff --git a/controller.go b/controller.go index 85894275..9d265758 100644 --- a/controller.go +++ b/controller.go @@ -261,12 +261,13 @@ func (c *Controller) Abort(code string) { // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. func (c *Controller) CustomAbort(status int, body string) { - c.Ctx.Output.Status = status // first panic from ErrorMaps, is is user defined error functions. if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status panic(body) } // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) c.Ctx.ResponseWriter.Write([]byte(body)) panic(ErrAbort) } diff --git a/error.go b/error.go index 4f48fab2..8cfa0c67 100644 --- a/error.go +++ b/error.go @@ -210,159 +210,139 @@ var ErrorMaps = make(map[string]*errorInfo, 10) // show 401 unauthorized error. func unauthorized(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(401), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested can't be authorized." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 402 Payment Required func paymentRequired(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(402), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested Payment Required." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 403 forbidden error. func forbidden(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(403), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is forbidden." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

", + ) } -// show 404 notfound error. +// show 404 not found error. func notFound(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(404), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested has flown the coop." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 405 Method Not Allowed func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(405), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The method you have requested Not Allowed." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 500 internal server error. func internalServerError(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(500), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is down right now." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

", + ) } // show 501 Not Implemented. func notImplemented(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(504), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is Not Implemented." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

", + ) } // show 502 Bad Gateway. func badGateway(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(502), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is down right now." + - "

") - t.Execute(rw, data) + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

", + ) } // show 503 service unavailable error. func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := map[string]interface{}{ - "Title": http.StatusText(503), - "BeegoVersion": VERSION, - } - data["Content"] = template.HTML("
The page you have requested is unavailable." + - "
Perhaps you are here because:" + - "

") - t.Execute(rw, data) + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

", + ) } // show 504 Gateway Timeout. func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

", + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := map[string]interface{}{ - "Title": http.StatusText(504), + "Title": http.StatusText(errCode), "BeegoVersion": VERSION, + "Content": errContent, } - data["Content"] = template.HTML("
The page you have requested is unavailable." + - "
Perhaps you are here because:" + - "

") t.Execute(rw, data) } @@ -408,7 +388,10 @@ func exception(errCode string, ctx *context.Context) { if err == nil { return v } - return 503 + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status } for _, ec := range []string{errCode, "503", "500"} { diff --git a/error_test.go b/error_test.go new file mode 100644 index 00000000..2fb8f962 --- /dev/null +++ b/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if string(w.Body.Bytes()) != parseCodeError { + t.Fail() + } +} diff --git a/filter_test.go b/filter_test.go index d9928d8d..4ca4d2b8 100644 --- a/filter_test.go +++ b/filter_test.go @@ -20,14 +20,8 @@ import ( "testing" "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" ) -func init() { - BeeLogger = logs.NewLogger(10000) - BeeLogger.SetLogger("console", "") -} - var FilterUser = func(ctx *context.Context) { ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) } diff --git a/hooks.go b/hooks.go index 59b10b32..3d784467 100644 --- a/hooks.go +++ b/hooks.go @@ -6,6 +6,8 @@ import ( "net/http" "path/filepath" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/session" ) @@ -70,7 +72,7 @@ func registerSession() error { func registerTemplate() error { if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { if BConfig.RunMode == DEV { - Warn(err) + logs.Warn(err) } return err } @@ -91,3 +93,14 @@ func registerAdmin() error { } return nil } + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/log.go b/log.go index 46ec57dd..e9412f92 100644 --- a/log.go +++ b/log.go @@ -33,82 +33,77 @@ const ( ) // BeeLogger references the used application logger. -var BeeLogger = logs.NewLogger(100) +var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. func SetLevel(l int) { - BeeLogger.SetLevel(l) + logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 func SetLogFuncCall(b bool) { - BeeLogger.EnableFuncCallDepth(b) - BeeLogger.SetLogFuncCallDepth(3) + logs.SetLogFuncCall(b) } // SetLogger sets a new logger. func SetLogger(adaptername string, config string) error { - err := BeeLogger.SetLogger(adaptername, config) - if err != nil { - return err - } - return nil + return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. func Emergency(v ...interface{}) { - BeeLogger.Emergency(generateFmtStr(len(v)), v...) + logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. func Alert(v ...interface{}) { - BeeLogger.Alert(generateFmtStr(len(v)), v...) + logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. func Critical(v ...interface{}) { - BeeLogger.Critical(generateFmtStr(len(v)), v...) + logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. func Error(v ...interface{}) { - BeeLogger.Error(generateFmtStr(len(v)), v...) + logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. func Warning(v ...interface{}) { - BeeLogger.Warning(generateFmtStr(len(v)), v...) + logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() func Warn(v ...interface{}) { - BeeLogger.Warn(generateFmtStr(len(v)), v...) + logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. func Notice(v ...interface{}) { - BeeLogger.Notice(generateFmtStr(len(v)), v...) + logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. func Informational(v ...interface{}) { - BeeLogger.Informational(generateFmtStr(len(v)), v...) + logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() func Info(v ...interface{}) { - BeeLogger.Info(generateFmtStr(len(v)), v...) + logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. func Debug(v ...interface{}) { - BeeLogger.Debug(generateFmtStr(len(v)), v...) + logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() func Trace(v ...interface{}) { - BeeLogger.Trace(generateFmtStr(len(v)), v...) + logs.Trace(generateFmtStr(len(v)), v...) } func generateFmtStr(n int) string { diff --git a/logs/conn.go b/logs/conn.go index 1db1a427..6d5bf6bf 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -113,5 +113,5 @@ func (c *connWriter) needToConnectOnMsg() bool { } func init() { - Register("conn", NewConn) + Register(AdapterConn, NewConn) } diff --git a/logs/console.go b/logs/console.go index 05d08a42..e6bf6c29 100644 --- a/logs/console.go +++ b/logs/console.go @@ -56,7 +56,7 @@ func NewConsole() Logger { cw := &consoleWriter{ lg: newLogWriter(os.Stdout), Level: LevelDebug, - Colorful: true, + Colorful: runtime.GOOS != "windows", } return cw } @@ -97,5 +97,5 @@ func (c *consoleWriter) Flush() { } func init() { - Register("console", NewConsole) + Register(AdapterConsole, NewConsole) } diff --git a/logs/es/es.go b/logs/es/es.go index 397ca2ef..22f4f650 100644 --- a/logs/es/es.go +++ b/logs/es/es.go @@ -76,5 +76,5 @@ func (el *esLogger) Flush() { } func init() { - logs.Register("es", NewES) + logs.Register(logs.AdapterEs, NewES) } diff --git a/logs/file.go b/logs/file.go index 9d3f78a0..5901143b 100644 --- a/logs/file.go +++ b/logs/file.go @@ -30,7 +30,7 @@ import ( // fileLogWriter implements LoggerInterface. // It writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { - sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize // The opened file Filename string `json:"filename"` fileWriter *os.File @@ -60,14 +60,11 @@ type fileLogWriter struct { // newFileWriter create a FileLogWriter returning as LoggerInterface. func newFileWriter() Logger { w := &fileLogWriter{ - Filename: "", - MaxLines: 1000000, - MaxSize: 1 << 28, //256 MB - Daily: true, - MaxDays: 7, - Rotate: true, - Level: LevelTrace, - Perm: 0660, + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: 0660, } return w } @@ -77,7 +74,7 @@ func newFileWriter() Logger { // { // "filename":"logs/beego.log", // "maxLines":10000, -// "maxsize":1<<30, +// "maxsize":1024, // "daily":true, // "maxDays":15, // "rotate":true, @@ -128,7 +125,9 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { h, d := formatTimeHeader(when) msg = string(h) + msg + "\n" if w.Rotate { + w.RLock() if w.needRotate(len(msg), d) { + w.RUnlock() w.Lock() if w.needRotate(len(msg), d) { if err := w.doRotate(when); err != nil { @@ -136,6 +135,8 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { } } w.Unlock() + } else { + w.RUnlock() } } @@ -255,8 +256,8 @@ func (w *fileLogWriter) deleteOldLog() { } }() - if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) { - if strings.HasPrefix(filepath.Base(path), w.fileNameOnly) && + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && strings.HasSuffix(filepath.Base(path), w.suffix) { os.Remove(path) } @@ -278,5 +279,5 @@ func (w *fileLogWriter) Flush() { } func init() { - Register("file", newFileWriter) + Register(AdapterFile, newFileWriter) } diff --git a/logs/log.go b/logs/log.go index c6ba3dc3..695691b2 100644 --- a/logs/log.go +++ b/logs/log.go @@ -35,10 +35,12 @@ package logs import ( "fmt" + "log" "os" "path" "runtime" "strconv" + "strings" "sync" "time" ) @@ -55,16 +57,28 @@ const ( LevelDebug ) -// Legacy loglevel constants to ensure backwards compatibility. -// -// Deprecated: will be removed in 1.5.0. +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "stmp" + AdapterConn = "conn" + AdapterEs = "es" +) + +// Legacy log level constants to ensure backwards compatibility. const ( LevelInfo = LevelInformational LevelTrace = LevelDebug LevelWarn = LevelWarning ) -type loggerType func() Logger +type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { @@ -74,12 +88,13 @@ type Logger interface { Flush() } -var adapters = make(map[string]loggerType) +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "} // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, // it panics. -func Register(name string, log loggerType) { +func Register(name string, log newLoggerFunc) { if log == nil { panic("logs: Register provide is nil") } @@ -94,15 +109,19 @@ func Register(name string, log loggerType) { type BeeLogger struct { lock sync.Mutex level int + init bool enableFuncCallDepth bool loggerFuncCallDepth int asynchronous bool + msgChanLen int64 msgChan chan *logMsg signalChan chan string wg sync.WaitGroup outputs []*nameLogger } +const defaultAsyncMsgLen = 1e3 + type nameLogger struct { Logger name string @@ -119,18 +138,31 @@ var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. // channelLen means the number of messages in chan(used where asynchronous is true). // if the buffering chan is full, logger adapters write to file or other way. -func NewLogger(channelLen int64) *BeeLogger { +func NewLogger(channelLens ...int64) *BeeLogger { bl := new(BeeLogger) bl.level = LevelDebug bl.loggerFuncCallDepth = 2 - bl.msgChan = make(chan *logMsg, channelLen) + bl.msgChanLen = append(channelLens, 0)[0] + if bl.msgChanLen <= 0 { + bl.msgChanLen = defaultAsyncMsgLen + } bl.signalChan = make(chan string, 1) + bl.setLogger(AdapterConsole) return bl } // Async set the log to asynchronous and start the goroutine -func (bl *BeeLogger) Async() *BeeLogger { +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + bl.lock.Lock() + defer bl.lock.Unlock() + if bl.asynchronous { + return bl + } bl.asynchronous = true + if len(msgLen) > 0 && msgLen[0] > 0 { + bl.msgChanLen = msgLen[0] + } + bl.msgChan = make(chan *logMsg, bl.msgChanLen) logMsgPool = &sync.Pool{ New: func() interface{} { return &logMsg{} @@ -143,23 +175,41 @@ func (bl *BeeLogger) Async() *BeeLogger { // SetLogger provides a given logger adapter into BeeLogger with config string. // config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) SetLogger(adapterName string, config string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - if log, ok := adapters[adapterName]; ok { - lg := log() - err := lg.Init(config) - if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) - return err +func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) } - bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) - } else { + } + + log, ok := adapters[adapterName] + if !ok { return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } + + lg := log() + err := lg.Init(config) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) return nil } +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLogger(adapterName, configs...) +} + // DelLogger remove a logger adapter in BeeLogger. func (bl *BeeLogger) DelLogger(adapterName string) error { bl.lock.Lock() @@ -188,7 +238,37 @@ func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { } } -func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // writeMsg will always add a '\n' character + if p[len(p)-1] == '\n' { + p = p[0 : len(p)-1] + } + // set levelLoggerImpl to ensure all log message will be write out + err = bl.writeMsg(levelLoggerImpl, string(p)) + if err == nil { + return len(p), err + } + return 0, err +} + +func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { + if !beeLogger.init { + bl.lock.Lock() + bl.setLogger(AdapterConsole) + bl.lock.Unlock() + } + if logLevel == levelLoggerImpl { + // set to emergency to ensure all log will be print out correctly + logLevel = LevelEmergency + } else { + msg = levelPrefix[logLevel] + msg + } + if len(v) > 0 { + msg = fmt.Sprintf(msg, v...) + } when := time.Now() if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) @@ -197,7 +277,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { line = 0 } _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg + msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg } if bl.asynchronous { lm := logMsgPool.Get().(*logMsg) @@ -265,8 +345,7 @@ func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { return } - msg := fmt.Sprintf("[M] "+format, v...) - bl.writeMsg(LevelEmergency, msg) + bl.writeMsg(LevelEmergency, format, v...) } // Alert Log ALERT level message. @@ -274,8 +353,7 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { if LevelAlert > bl.level { return } - msg := fmt.Sprintf("[A] "+format, v...) - bl.writeMsg(LevelAlert, msg) + bl.writeMsg(LevelAlert, format, v...) } // Critical Log CRITICAL level message. @@ -283,8 +361,7 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { if LevelCritical > bl.level { return } - msg := fmt.Sprintf("[C] "+format, v...) - bl.writeMsg(LevelCritical, msg) + bl.writeMsg(LevelCritical, format, v...) } // Error Log ERROR level message. @@ -292,17 +369,12 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { if LevelError > bl.level { return } - msg := fmt.Sprintf("[E] "+format, v...) - bl.writeMsg(LevelError, msg) + bl.writeMsg(LevelError, format, v...) } // Warning Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { - if LevelWarning > bl.level { - return - } - msg := fmt.Sprintf("[W] "+format, v...) - bl.writeMsg(LevelWarning, msg) + bl.Warn(format, v...) } // Notice Log NOTICE level message. @@ -310,17 +382,12 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { if LevelNotice > bl.level { return } - msg := fmt.Sprintf("[N] "+format, v...) - bl.writeMsg(LevelNotice, msg) + bl.writeMsg(LevelNotice, format, v...) } // Informational Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { - if LevelInformational > bl.level { - return - } - msg := fmt.Sprintf("[I] "+format, v...) - bl.writeMsg(LevelInformational, msg) + bl.Info(format, v...) } // Debug Log DEBUG level message. @@ -328,38 +395,31 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { if LevelDebug > bl.level { return } - msg := fmt.Sprintf("[D] "+format, v...) - bl.writeMsg(LevelDebug, msg) + bl.writeMsg(LevelDebug, format, v...) } // Warn Log WARN level message. // compatibility alias for Warning() func (bl *BeeLogger) Warn(format string, v ...interface{}) { - if LevelWarning > bl.level { + if LevelWarn > bl.level { return } - msg := fmt.Sprintf("[W] "+format, v...) - bl.writeMsg(LevelWarning, msg) + bl.writeMsg(LevelWarn, format, v...) } // Info Log INFO level message. // compatibility alias for Informational() func (bl *BeeLogger) Info(format string, v ...interface{}) { - if LevelInformational > bl.level { + if LevelInfo > bl.level { return } - msg := fmt.Sprintf("[I] "+format, v...) - bl.writeMsg(LevelInformational, msg) + bl.writeMsg(LevelInfo, format, v...) } // Trace Log TRACE level message. // compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - msg := fmt.Sprintf("[D] "+format, v...) - bl.writeMsg(LevelDebug, msg) + bl.Debug(format, v...) } // Flush flush all chan data. @@ -378,6 +438,7 @@ func (bl *BeeLogger) Close() { if bl.asynchronous { bl.signalChan <- "close" bl.wg.Wait() + close(bl.msgChan) } else { bl.flush() for _, l := range bl.outputs { @@ -385,7 +446,6 @@ func (bl *BeeLogger) Close() { } bl.outputs = nil } - close(bl.msgChan) close(bl.signalChan) } @@ -399,16 +459,175 @@ func (bl *BeeLogger) Reset() { } func (bl *BeeLogger) flush() { - for { - if len(bl.msgChan) > 0 { - bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - continue + if bl.asynchronous { + for { + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + continue + } + break } - break } for _, l := range bl.outputs { l.Flush() } } + +// beeLogger references the used application logger. +var beeLogger *BeeLogger = NewLogger() + +// GetLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return beeLogger +} + +var beeLoggerMap = struct { + sync.RWMutex + logs map[string]*log.Logger +}{ + logs: map[string]*log.Logger{}, +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + prefix := append(prefixes, "")[0] + if prefix != "" { + prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) + } + beeLoggerMap.RLock() + l, ok := beeLoggerMap.logs[prefix] + if ok { + beeLoggerMap.RUnlock() + return l + } + beeLoggerMap.RUnlock() + beeLoggerMap.Lock() + defer beeLoggerMap.Unlock() + l, ok = beeLoggerMap.logs[prefix] + if !ok { + l = log.New(beeLogger, prefix, 0) + beeLoggerMap.logs[prefix] = l + } + return l +} + +// Reset will remove all the adapter +func Reset() { + beeLogger.Reset() +} + +func Async(msgLen ...int64) *BeeLogger { + return beeLogger.Async(msgLen...) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + beeLogger.SetLevel(l) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + beeLogger.enableFuncCallDepth = b +} + +// SetLogFuncCall set the CallDepth, default is 3 +func SetLogFuncCall(b bool) { + beeLogger.EnableFuncCallDepth(b) + beeLogger.SetLogFuncCallDepth(3) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + beeLogger.loggerFuncCallDepth = d +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + err := beeLogger.SetLogger(adapter, config...) + if err != nil { + return err + } + return nil +} + +// Emergency logs a message at emergency level. +func Emergency(f interface{}, v ...interface{}) { + beeLogger.Emergency(formatLog(f, v...)) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + beeLogger.Alert(formatLog(f, v...)) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + beeLogger.Critical(formatLog(f, v...)) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + beeLogger.Error(formatLog(f, v...)) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + beeLogger.Notice(formatLog(f, v...)) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + beeLogger.Debug(formatLog(f, v...)) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + beeLogger.Trace(formatLog(f, v...)) +} + +func formatLog(f interface{}, v ...interface{}) string { + var msg string + switch f.(type) { + case string: + msg = f.(string) + if len(v) == 0 { + return msg + } + if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { + //format string + } else { + //do not contain format char + msg += strings.Repeat(" %v", len(v)) + } + default: + msg = fmt.Sprint(f) + if len(v) == 0 { + return msg + } + msg += strings.Repeat(" %v", len(v)) + } + return fmt.Sprintf(msg, v...) +} diff --git a/logs/logger.go b/logs/logger.go index 323c41c5..b25bfaef 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - package logs import ( diff --git a/logs/multifile.go b/logs/multifile.go index b82ba274..63204e17 100644 --- a/logs/multifile.go +++ b/logs/multifile.go @@ -112,5 +112,5 @@ func newFilesWriter() Logger { } func init() { - Register("multifile", newFilesWriter) + Register(AdapterMultiFile, newFilesWriter) } diff --git a/logs/smtp.go b/logs/smtp.go index 47f5a0c6..834130ef 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -156,5 +156,5 @@ func (s *SMTPWriter) Destroy() { } func init() { - Register("smtp", newSMTPWriter) + Register(AdapterMail, newSMTPWriter) } diff --git a/migration/migration.go b/migration/migration.go index 1591bc50..c9ca1bc6 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -33,7 +33,7 @@ import ( "strings" "time" - "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/orm" ) @@ -90,7 +90,7 @@ func (m *Migration) Reset() { func (m *Migration) Exec(name, status string) error { o := orm.NewOrm() for _, s := range m.sqls { - beego.Info("exec sql:", s) + logs.Info("exec sql:", s) r := o.Raw(s) _, err := r.Exec() if err != nil { @@ -144,20 +144,20 @@ func Upgrade(lasttime int64) error { i := 0 for _, v := range sm { if v.created > lasttime { - beego.Info("start upgrade", v.name) + logs.Info("start upgrade", v.name) v.m.Reset() v.m.Up() err := v.m.Exec(v.name, "up") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } - beego.Info("end upgrade:", v.name) + logs.Info("end upgrade:", v.name) i++ } } - beego.Info("total success upgrade:", i, " migration") + logs.Info("total success upgrade:", i, " migration") time.Sleep(2 * time.Second) return nil } @@ -165,20 +165,20 @@ func Upgrade(lasttime int64) error { // Rollback rollback the migration by the name func Rollback(name string) error { if v, ok := migrationMap[name]; ok { - beego.Info("start rollback") + logs.Info("start rollback") v.Reset() v.Down() err := v.Exec(name, "down") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } - beego.Info("end rollback") + logs.Info("end rollback") time.Sleep(2 * time.Second) return nil } - beego.Error("not exist the migrationMap name:" + name) + logs.Error("not exist the migrationMap name:" + name) time.Sleep(2 * time.Second) return errors.New("not exist the migrationMap name:" + name) } @@ -191,23 +191,23 @@ func Reset() error { for j := len(sm) - 1; j >= 0; j-- { v := sm[j] if isRollBack(v.name) { - beego.Info("skip the", v.name) + logs.Info("skip the", v.name) time.Sleep(1 * time.Second) continue } - beego.Info("start reset:", v.name) + logs.Info("start reset:", v.name) v.m.Reset() v.m.Down() err := v.m.Exec(v.name, "down") if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } i++ - beego.Info("end reset:", v.name) + logs.Info("end reset:", v.name) } - beego.Info("total success reset:", i, " migration") + logs.Info("total success reset:", i, " migration") time.Sleep(2 * time.Second) return nil } @@ -216,7 +216,7 @@ func Reset() error { func Refresh() error { err := Reset() if err != nil { - beego.Error("execute error:", err) + logs.Error("execute error:", err) time.Sleep(2 * time.Second) return err } @@ -265,7 +265,7 @@ func isRollBack(name string) bool { var maps []orm.Params num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) if err != nil { - beego.Info("get name has error", err) + logs.Info("get name has error", err) return false } if num <= 0 { diff --git a/namespace.go b/namespace.go index 0dfdd7af..cfde0111 100644 --- a/namespace.go +++ b/namespace.go @@ -44,7 +44,7 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { return ns } -// Cond set condtion function +// Cond set condition function // if cond return true can run this namespace, else can't // usage: // ns.Cond(func (ctx *context.Context) bool{ @@ -60,7 +60,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace { exception("405", ctx) } } - if v, ok := n.handlers.filters[BeforeRouter]; ok { + if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = "*" @@ -388,3 +388,10 @@ func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { ns.Namespace(n) } } + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + ns.Handler(rootpath, h) + } +} diff --git a/orm/db.go b/orm/db.go index 314c3535..efc90e6e 100644 --- a/orm/db.go +++ b/orm/db.go @@ -71,12 +71,12 @@ type dbBase struct { var _ dbBaser = new(dbBase) // get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { - var columns []string - - if names != nil { - columns = *names +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { + if names == nil { + ns := make([]string, 0, len(cols)) + names = &ns } + values = make([]interface{}, 0, len(cols)) for _, column := range cols { var fi *fieldInfo @@ -90,18 +90,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) if err != nil { - return nil, err + return nil, nil, err } - if names != nil { - columns = append(columns, column) + // ignore empty value auto field + if insert && fi.auto { + if fi.fieldType&IsPositiveIntegerField > 0 { + if vu, ok := value.(uint64); !ok || vu == 0 { + continue + } + } else { + if vu, ok := value.(int64); !ok || vu == 0 { + continue + } + } + autoFields = append(autoFields, fi.column) } - values = append(values, value) - } - - if names != nil { - *names = columns + *names, values = append(*names, column), append(values, value) } return @@ -181,7 +187,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } default: switch { - case fi.fieldType&IsPostiveIntegerField > 0: + case fi.fieldType&IsPositiveIntegerField > 0: if field.Kind() == reflect.Ptr { if field.IsNil() { value = nil @@ -273,7 +279,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, // insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err } @@ -300,7 +306,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo if len(cols) > 0 { var err error whereCols = make([]string, 0, len(cols)) - args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) + args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) if err != nil { return err } @@ -349,13 +355,21 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo // execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)-1) - values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + names := make([]string, 0, len(mi.fields.dbcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - return d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(q, mi, false, names, values) + if err != nil { + return 0, err + } + + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + return id, err } // multi-insert sql with given slice struct reflect.Value. @@ -369,7 +383,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // typ := reflect.Indirect(mi.addrField).Type() - length := sind.Len() + length, autoFields := sind.Len(), make([]string, 0, 1) for i := 1; i <= length; i++ { @@ -381,16 +395,18 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // } if i == 1 { - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + var ( + vus []interface{} + err error + ) + vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return cnt, err } values = make([]interface{}, bulk*len(vus)) nums += copy(values, vus) - } else { - - vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) if err != nil { return cnt, err } @@ -412,7 +428,12 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } } - return cnt, nil + var err error + if len(autoFields) > 0 { + err = d.ins.setval(q, mi, autoFields) + } + + return cnt, err } // execute insert sql with given struct and given values. @@ -472,7 +493,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. setNames = make([]string, 0, len(cols)) } - setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) + setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) if err != nil { return 0, err } @@ -516,7 +537,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. } if num > 0 { if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) } else { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) @@ -1140,7 +1161,7 @@ setValue: tErr = err goto end } - if fieldType&IsPostiveIntegerField > 0 { + if fieldType&IsPositiveIntegerField > 0 { v, _ := str.Uint64() value = v } else { @@ -1292,7 +1313,7 @@ setValue: field.Set(reflect.ValueOf(&v)) } case fieldType&IsIntegerField > 0: - if fieldType&IsPostiveIntegerField > 0 { + if fieldType&IsPositiveIntegerField > 0 { if isNative { if value == nil { value = uint64(0) @@ -1562,6 +1583,11 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// sync auto key +func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + return nil +} + // convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) diff --git a/orm/db_alias.go b/orm/db_alias.go index 79576b8e..b6c833a7 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -59,6 +59,7 @@ var ( "postgres": DRPostgres, "sqlite3": DRSqlite, "tidb": DRTiDB, + "oracle": DROracle, } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -151,7 +152,7 @@ func detectTZ(al *alias) { al.Engine = "INNODB" } - case DRSqlite: + case DRSqlite, DROracle: al.TZ = time.UTC case DRPostgres: diff --git a/orm/db_oracle.go b/orm/db_oracle.go index 1e385c9a..deca36ad 100644 --- a/orm/db_oracle.go +++ b/orm/db_oracle.go @@ -14,6 +14,41 @@ package orm +import ( + "fmt" + "strings" +) + +// oracle operators. +var oracleOperators = map[string]string{ + "exact": "= ?", + "gt": "> ?", + "gte": ">= ?", + "lt": "< ?", + "lte": "<= ?", + "//iendswith": "LIKE ?", +} + +// oracle column field types. +var oracleTypes = map[string]string{ + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", +} + // oracle dbBaser type dbBaseOracle struct { dbBase @@ -27,3 +62,35 @@ func newdbBaseOracle() dbBaser { b.ins = b return b } + +// OperatorSQL get oracle operator. +func (d *dbBaseOracle) OperatorSQL(operator string) string { + return oracleOperators[operator] +} + +// DbTypes get oracle table field types. +func (d *dbBaseOracle) DbTypes() map[string]string { + return oracleTypes +} + +//ShowTablesQuery show all the tables in database +func (d *dbBaseOracle) ShowTablesQuery() string { + return "SELECT TABLE_NAME FROM USER_TABLES" +} + +// Oracle +func (d *dbBaseOracle) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ + "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) +} + +// check index is exist +func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ + "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ + "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) + + var cnt int + row.Scan(&cnt) + return cnt > 0 +} diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 7dbef95a..8fbcb88d 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -123,14 +123,35 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { } // make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { - if mi.fields.pk.auto { - if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) - } - has = true +func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { + fi := mi.fields.pk + if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { + return false } - return + + if query != nil { + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + } + return true +} + +// sync auto key +func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { + if len(autoFields) == 0 { + return nil + } + + Q := d.ins.TableQuote() + for _, name := range autoFields { + query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", + mi.table, name, + Q, name, Q, + Q, mi.table, Q) + if _, err := db.Exec(query); err != nil { + return err + } + } + return nil } // show table sql for postgresql. diff --git a/orm/db_utils.go b/orm/db_utils.go index c97caf36..ff36b286 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -33,13 +33,13 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac fi := mi.fields.pk v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPostiveIntegerField > 0 { + if fi.fieldType&IsPositiveIntegerField > 0 { vu := v.Uint() exist = vu > 0 value = vu } else if fi.fieldType&IsIntegerField > 0 { vu := v.Int() - exist = vu > 0 + exist = true value = vu } else { vu := v.String() diff --git a/orm/models_fields.go b/orm/models_fields.go index a8cf8e4f..ad1fb6f9 100644 --- a/orm/models_fields.go +++ b/orm/models_fields.go @@ -46,10 +46,10 @@ const ( // Define some logic enum const ( - IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 - IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 - IsRelField = ^-RelReverseMany >> 14 << 15 - IsFieldType = ^-RelReverseMany<<1 + 1 + IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 + IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 + IsRelField = ^-RelReverseMany >> 14 << 15 + IsFieldType = ^-RelReverseMany<<1 + 1 ) // BooleanField A true/false field. diff --git a/orm/models_test.go b/orm/models_test.go index 1ff4e53b..f42d725e 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -352,7 +352,7 @@ type GroupPermissions struct { } type ModelID struct { - Id int64 + ID int64 } type ModelBase struct { @@ -375,6 +375,28 @@ func NewInLine() *InLine { return new(InLine) } +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + Id int64 `orm:"pk"` + Value string +} + +type UintPk struct { + Id uint32 `orm:"pk"` + Name string +} + var DBARGS = struct { Driver string Source string diff --git a/orm/orm.go b/orm/orm.go index 0ffb6b86..6f4b7731 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -140,7 +140,14 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i return (err == nil), id, err } - return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err + id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else { + id = vid.Int() + } + + return false, id, err } // insert model data to database @@ -159,7 +166,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { // set auto pk field func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) } else { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) diff --git a/orm/orm_log.go b/orm/orm_log.go index 712eb219..54723273 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -31,7 +31,7 @@ type Log struct { // NewLog set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) - d.Logger = log.New(out, "[ORM]", 1e9) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) return d } diff --git a/orm/orm_object.go b/orm/orm_object.go index 8a5d85e2..de3181ce 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -50,7 +50,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { } if id > 0 { if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) } else { ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) diff --git a/orm/orm_test.go b/orm/orm_test.go index f638fba4..3f6e1566 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" "io/ioutil" + "math" "os" "path/filepath" "reflect" @@ -188,6 +189,9 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Permission)) RegisterModel(new(GroupPermissions)) RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -208,6 +212,9 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Permission)) RegisterModel(new(GroupPermissions)) RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) BootStrap() @@ -1943,7 +1950,7 @@ func TestInLine(t *testing.T) { throwFail(t, AssertIs(id, 1)) il := NewInLine() - il.Id = 1 + il.ID = 1 err = dORM.Read(il) throwFail(t, err) @@ -1952,3 +1959,120 @@ func TestInLine(t *testing.T) { throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) } + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {Id: math.MinInt64, Value: "-"}, + {Id: 0, Value: "0"}, + {Id: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{Id: intPk.Id} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + Id: 8, + Name: name, + } + + created, pk, err := dORM.ReadOrCreate(u, "Id") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{Id: 8} + created, pk, err = dORM.ReadOrCreate(nu, "Id") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.Id, u.Id)) + throwFail(t, AssertIs(pk, u.Id)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} diff --git a/orm/types.go b/orm/types.go index 41933dd1..cb55e71a 100644 --- a/orm/types.go +++ b/orm/types.go @@ -420,4 +420,5 @@ type dbBaser interface { ShowColumnsQuery(string) string IndexExists(dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(dbQuerier, *modelInfo, []string) error } diff --git a/parser.go b/parser.go index f23f4720..24f7549b 100644 --- a/parser.go +++ b/parser.go @@ -27,6 +27,7 @@ import ( "sort" "strings" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -58,7 +59,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { rep := strings.NewReplacer("/", "_", ".", "_") commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go" if !compareFile(pkgRealpath) { - Info(pkgRealpath + " has not changed, not reloading") + logs.Info(pkgRealpath + " no changed") return nil } genInfoList = make(map[string][]ControllerComments) @@ -131,7 +132,7 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat func genRouterCode() { os.Mkdir(path.Join(AppPath, "routers"), 0755) - Info("generate router from comments") + logs.Info("generate router from comments") var ( globalinfo string sortKey []string diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go index 8af08088..970367c9 100644 --- a/plugins/apiauth/apiauth.go +++ b/plugins/apiauth/apiauth.go @@ -35,7 +35,7 @@ // // beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) // -// Infomation: +// Information: // // In the request user should include these params in the query // diff --git a/router.go b/router.go index d0bf534f..960cd104 100644 --- a/router.go +++ b/router.go @@ -28,6 +28,7 @@ import ( "time" beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/utils" ) @@ -114,7 +115,7 @@ type controllerInfo struct { type ControllerRegister struct { routers map[string]*Tree enableFilter bool - filters map[int][]*FilterRouter + filters [FinishRouter + 1][]*FilterRouter pool sync.Pool } @@ -122,7 +123,6 @@ type ControllerRegister struct { func NewControllerRegister() *ControllerRegister { cr := &ControllerRegister{ routers: make(map[string]*Tree), - filters: make(map[int][]*FilterRouter), } cr.pool.New = func() interface{} { return beecontext.NewContext() @@ -408,7 +408,6 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // InsertFilter Add a FilterFunc with pattern rule and action constant. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = pattern @@ -426,9 +425,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter } // add Filter into -func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error { - p.filters[pos] = append(p.filters[pos], mr) +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + err = fmt.Errorf("can not find your filter postion") + return + } p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) return nil } @@ -437,11 +440,11 @@ func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { paths := strings.Split(endpoint, ".") if len(paths) <= 1 { - Warn("urlfor endpoint must like path.controller.method") + logs.Warn("urlfor endpoint must like path.controller.method") return "" } if len(values)%2 != 0 { - Warn("urlfor params must key-value pair") + logs.Warn("urlfor params must key-value pair") return "" } params := make(map[string]string) @@ -577,20 +580,16 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin return false, "" } -func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { - if p.enableFilter { - if l, ok := p.filters[pos]; ok { - for _, filterR := range l { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - } +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + for _, filterR := range p.filters[pos] { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true } } return false @@ -617,11 +616,10 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) context.Output.Header("Server", BConfig.ServerName) } - var urlPath string + var urlPath = r.URL.Path + if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(r.URL.Path) - } else { - urlPath = r.URL.Path + urlPath = strings.ToLower(urlPath) } // filter wrong http method @@ -631,11 +629,12 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // filter for static file - if p.execFilter(context, BeforeStatic, urlPath) { + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { goto Admin } serverStaticRouter(context) + if context.ResponseWriter.Started { findRouter = true goto Admin @@ -653,9 +652,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) var err error context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { - Error(err) + logs.Error(err) exception("503", context) - return + goto Admin } defer func() { if context.Input.CruSession != nil { @@ -663,8 +662,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } }() } - - if p.execFilter(context, BeforeRouter, urlPath) { + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { goto Admin } @@ -693,7 +691,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if findRouter { //execute middleware filters - if p.execFilter(context, BeforeExec, urlPath) { + if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { goto Admin } isRunnable := false @@ -783,7 +781,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !context.ResponseWriter.Started && context.Output.Status == 0 { if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { - panic(err) + logs.Error(err) } } } @@ -794,17 +792,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } //execute middleware filters - if p.execFilter(context, AfterExec, urlPath) { + if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { goto Admin } } - - p.execFilter(context, FinishRouter, urlPath) + if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + goto Admin + } Admin: - timeDur := time.Since(startTime) //admin module record QPS if BConfig.Listen.EnableAdmin { + timeDur := time.Since(startTime) if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { if runRouter != nil { go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) @@ -815,6 +814,7 @@ Admin: } if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { + timeDur := time.Since(startTime) var devInfo string if findRouter { if routerInfo != nil { @@ -826,7 +826,7 @@ Admin: devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch") } if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) { - Debug(devInfo) + logs.Debug(devInfo) } } @@ -843,27 +843,26 @@ func (p *ControllerRegister) recoverPanic(context *beecontext.Context) { } if !BConfig.RecoverPanic { panic(err) - } else { - if BConfig.EnableErrorsShow { - if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { - exception(fmt.Sprint(err), context) - return - } + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), context) + return } - var stack string - Critical("the request url is ", context.Input.URL()) - Critical("Handler crashed with error", err) - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - Critical(fmt.Sprintf("%s:%d", file, line)) - stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) - } - if BConfig.RunMode == DEV { - showErr(err, context, stack) + } + var stack string + logs.Critical("the request url is ", context.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV { + showErr(err, context, stack) } } } diff --git a/router_test.go b/router_test.go index f26f0c86..9f11286c 100644 --- a/router_test.go +++ b/router_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" ) type TestController struct { @@ -94,7 +95,7 @@ func TestUrlFor(t *testing.T) { handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/person/:last/:first", &TestController{}, "*:Param") if a := handler.URLFor("TestController.List"); a != "/api/list" { - Info(a) + logs.Info(a) t.Errorf("TestController.List must equal to /api/list") } if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { @@ -120,24 +121,24 @@ func TestUrlFor2(t *testing.T) { handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { - Info(handler.URLFor("TestController.GetURL")) + logs.Info(handler.URLFor("TestController.GetURL")) t.Errorf("TestController.List must equal to /v1/astaxie/edit") } if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != "/v1/za/cms_12_123.html" { - Info(handler.URLFor("TestController.List")) + logs.Info(handler.URLFor("TestController.List")) t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") } if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != "/v1/za_cms/ttt_12_123.html" { - Info(handler.URLFor("TestController.Param")) + logs.Info(handler.URLFor("TestController.Param")) t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") } if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", ":title", "aaaa", ":entid", "aaaa") != "/1111/11/aaaa/aaaa" { - Info(handler.URLFor("TestController.Get")) + logs.Info(handler.URLFor("TestController.Get")) t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") } } diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 969d26c9..838ec669 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -115,7 +115,6 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) { } st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", b, time.Now().Unix(), st.sid) - } // Provider mysql session provider diff --git a/session/sess_file.go b/session/sess_file.go index 9265b030..91acfcd4 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -16,7 +16,6 @@ package session import ( "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -82,14 +81,17 @@ func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { b, err := EncodeGob(fs.values) if err != nil { + SLogger.Println(err) return } _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) var f *os.File if err == nil { f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + SLogger.Println(err) } else if os.IsNotExist(err) { f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + SLogger.Println(err) } else { return } @@ -123,7 +125,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var f *os.File @@ -191,7 +193,7 @@ func (fp *FileProvider) SessionAll() int { return a.visit(path, f, err) }) if err != nil { - fmt.Printf("filepath.Walk() returned %v\n", err) + SLogger.Printf("filepath.Walk() returned %v\n", err) return 0 } return a.total @@ -205,11 +207,11 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) if err != nil { - println(err.Error()) + SLogger.Println(err.Error()) } _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var newf *os.File diff --git a/session/session.go b/session/session.go index 9fe99a17..2c4b9351 100644 --- a/session/session.go +++ b/session/session.go @@ -32,8 +32,11 @@ import ( "encoding/hex" "encoding/json" "fmt" + "io" + "log" "net/http" "net/url" + "os" "time" ) @@ -61,6 +64,9 @@ type Provider interface { var provides = make(map[string]Provider) +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + // Register makes a session provide available by the provided name. // If Register is called twice with the same name or if driver is nil, // it panics. @@ -296,3 +302,15 @@ func (manager *Manager) isSecure(req *http.Request) bool { } return true } + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + sl := new(Log) + sl.Logger = log.New(out, "[SESSION]", 1e9) + return sl +} diff --git a/staticfile.go b/staticfile.go index 0aad2c81..8a1bc57b 100644 --- a/staticfile.go +++ b/staticfile.go @@ -27,6 +27,7 @@ import ( "time" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" ) var errNotStaticRequest = errors.New("request not a static file request") @@ -48,14 +49,19 @@ func serverStaticRouter(ctx *context.Context) { if filePath == "" || fileInfo == nil { if BConfig.RunMode == DEV { - Warn("Can't find/open the file:", filePath, err) + logs.Warn("Can't find/open the file:", filePath, err) } http.NotFound(ctx.ResponseWriter, ctx.Request) return } if fileInfo.IsDir() { - //serveFile will list dir - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + requestURL := ctx.Input.URL() + if requestURL[len(requestURL)-1] != '/' { + ctx.Redirect(302, requestURL+"/") + } else { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + } return } @@ -67,7 +73,7 @@ func serverStaticRouter(ctx *context.Context) { b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) if err != nil { if BConfig.RunMode == DEV { - Warn("Can't compress the file:", filePath, err) + logs.Warn("Can't compress the file:", filePath, err) } http.NotFound(ctx.ResponseWriter, ctx.Request) return diff --git a/template.go b/template.go index e6c43f87..a7c719e8 100644 --- a/template.go +++ b/template.go @@ -26,6 +26,7 @@ import ( "strings" "sync" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -46,7 +47,7 @@ func executeTemplate(wr io.Writer, name string, data interface{}) error { if t, ok := beeTemplates[name]; ok { err := t.ExecuteTemplate(wr, name, data) if err != nil { - Trace("template Execute err:", err) + logs.Trace("template Execute err:", err) } return err } @@ -162,7 +163,7 @@ func BuildTemplate(dir string, files ...string) error { templatesLock.Lock() t, err := getTemplate(self.root, file, v...) if err != nil { - Trace("parse template err:", file, err) + logs.Trace("parse template err:", file, err) } else { beeTemplates[file] = t } @@ -240,7 +241,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others var subMods1 [][]string t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { - Trace("template parse file err:", err) + logs.Trace("template parse file err:", err) } else if subMods1 != nil && len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } @@ -261,7 +262,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others var subMods1 [][]string t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { - Trace("template parse file err:", err) + logs.Trace("template parse file err:", err) } else if subMods1 != nil && len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } diff --git a/tree.go b/tree.go index 594e9999..25b78e50 100644 --- a/tree.go +++ b/tree.go @@ -141,7 +141,7 @@ func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg st regexpStr = "([^.]+).(.+)" params = params[1:] } else { - for _ = range params { + for range params { regexpStr = "([^/]+)/" + regexpStr } } @@ -254,7 +254,7 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, regexpStr = "/([^.]+).(.+)" params = params[1:] } else { - for _ = range params { + for range params { regexpStr = "/([^/]+)" + regexpStr } } @@ -389,7 +389,7 @@ type leafInfo struct { func (leaf *leafInfo) match(wildcardValues []string, ctx *context.Context) (ok bool) { //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { - if len(wildcardValues) == 0 { // static path + if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path return true } // match * diff --git a/tree_test.go b/tree_test.go index 531df046..81ff7edd 100644 --- a/tree_test.go +++ b/tree_test.go @@ -97,6 +97,21 @@ func TestTreeRouters(t *testing.T) { } } +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + func TestAddTree(t *testing.T) { tr := NewTree() tr.AddRouter("/shop/:id/account", "astaxie") diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go index 1a4a6edc..42ac70d3 100644 --- a/utils/captcha/captcha.go +++ b/utils/captcha/captcha.go @@ -69,6 +69,7 @@ import ( "github.com/astaxie/beego" "github.com/astaxie/beego/cache" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" "github.com/astaxie/beego/utils" ) @@ -139,7 +140,7 @@ func (c *Captcha) Handler(ctx *context.Context) { if err := c.store.Put(key, chars, c.Expiration); err != nil { ctx.Output.SetStatus(500) ctx.WriteString("captcha reload error") - beego.Error("Reload Create Captcha Error:", err) + logs.Error("Reload Create Captcha Error:", err) return } } else { @@ -154,7 +155,7 @@ func (c *Captcha) Handler(ctx *context.Context) { img := NewImage(chars, c.StdWidth, c.StdHeight) if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { - beego.Error("Write Captcha Image Error:", err) + logs.Error("Write Captcha Image Error:", err) } } @@ -162,7 +163,7 @@ func (c *Captcha) Handler(ctx *context.Context) { func (c *Captcha) CreateCaptchaHTML() template.HTML { value, err := c.CreateCaptcha() if err != nil { - beego.Error("Create Captcha Error:", err) + logs.Error("Create Captcha Error:", err) return "" } diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go index 1d99cac5..2f022d0c 100644 --- a/utils/pagination/controller.go +++ b/utils/pagination/controller.go @@ -18,7 +18,7 @@ import ( "github.com/astaxie/beego/context" ) -// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data["paginator"]. +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { paginator = NewPaginator(context.Request, per, nums) context.Input.SetData("paginator", &paginator)