diff --git a/.gosimpleignore b/.gosimpleignore deleted file mode 100644 index 84df9b95..00000000 --- a/.gosimpleignore +++ /dev/null @@ -1,4 +0,0 @@ -github.com/astaxie/beego/*/*:S1012 -github.com/astaxie/beego/*:S1012 -github.com/astaxie/beego/*/*:S1007 -github.com/astaxie/beego/*:S1007 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index ed04c9d1..1bb121a2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - "1.10.x" - "1.11.x" services: - redis-server @@ -9,9 +8,19 @@ services: - postgresql - memcached env: - - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + global: + - GO_REPO_FULLNAME="github.com/astaxie/beego" + matrix: + - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db + - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" before_install: + # link the local repo with ${GOPATH}/src// + - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} + # relies on GOPATH to contain only one directory... + - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} + - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} + - cd ${GOPATH}/src/${GO_REPO_FULLNAME} + # get and build ssdb - git clone git://github.com/ideawu/ssdb.git - cd ssdb - make @@ -35,7 +44,9 @@ install: - go get github.com/Knetic/govaluate - go get github.com/casbin/casbin - go get github.com/elazarl/go-bindata-assetfs - - go get -u honnef.co/go/tools/cmd/gosimple + - go get github.com/OwnLocal/goes + - go get github.com/shiena/ansicolor + - go get -u honnef.co/go/tools/cmd/staticcheck - go get -u github.com/mdempsky/unconvert - go get -u github.com/gordonklaus/ineffassign - go get -u github.com/golang/lint/golint @@ -54,7 +65,7 @@ after_script: - rm -rf ./res/var/* script: - go test -v ./... - - gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/) + - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s diff --git a/app.go b/app.go index 32ff159d..d9e85e9b 100644 --- a/app.go +++ b/app.go @@ -176,7 +176,7 @@ func (app *App) Run(mws ...MiddleWare) { if BConfig.Listen.HTTPSPort != 0 { app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) } else if BConfig.Listen.EnableHTTP { - BeeLogger.Info("Start https server error, conflict with http. Please reset https port") + logs.Info("Start https server error, conflict with http. Please reset https port") return } logs.Info("https server Running on https://%s", app.Server.Addr) @@ -192,7 +192,7 @@ func (app *App) Run(mws ...MiddleWare) { pool := x509.NewCertPool() data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) if err != nil { - BeeLogger.Info("MutualHTTPS should provide TrustCaFile") + logs.Info("MutualHTTPS should provide TrustCaFile") return } pool.AppendCertsFromPEM(data) diff --git a/beego.go b/beego.go index ff89f2f5..66b19f36 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.11.1" + VERSION = "1.12.0" // DEV is for develop DEV = "dev" diff --git a/cache/cache_test.go b/cache/cache_test.go index 9ceb606a..470c0a43 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -16,10 +16,33 @@ package cache import ( "os" + "sync" "testing" "time" ) +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + func TestCache(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) if err != nil { @@ -98,7 +121,7 @@ func TestCache(t *testing.T) { } func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`) + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) if err != nil { t.Error("init err") } diff --git a/cache/file.go b/cache/file.go index 691ce7cd..6f12d3ee 100644 --- a/cache/file.go +++ b/cache/file.go @@ -62,11 +62,14 @@ func NewFileCache() Cache { } // StartAndGC will start and begin gc for file cache. -// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} func (fc *FileCache) StartAndGC(config string) error { - var cfg map[string]string - json.Unmarshal([]byte(config), &cfg) + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } if _, ok := cfg["CachePath"]; !ok { cfg["CachePath"] = FileCachePath } @@ -142,12 +145,12 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} { // Put value into file cache. // timeout means how long to keep this file, unit of ms. -// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { gob.Register(val) item := FileCacheItem{Data: val} - if timeout == FileCacheEmbedExpiry { + if timeout == time.Duration(fc.EmbedExpiry) { item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years } else { item.Expired = time.Now().Add(timeout) @@ -179,7 +182,7 @@ func (fc *FileCache) Incr(key string) error { } else { incr = data.(int) + 1 } - fc.Put(key, incr, FileCacheEmbedExpiry) + fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) return nil } @@ -192,7 +195,7 @@ func (fc *FileCache) Decr(key string) error { } else { decr = data.(int) - 1 } - fc.Put(key, decr, FileCacheEmbedExpiry) + fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) return nil } diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go index 0624f5fa..19116bfa 100644 --- a/cache/memcache/memcache.go +++ b/cache/memcache/memcache.go @@ -146,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool { } } _, err := rc.conn.Get(key) - return !(err != nil) + return err == nil } // ClearAll clear all cached in memcache. diff --git a/cache/memory.go b/cache/memory.go index cb9802ab..1fec2eff 100644 --- a/cache/memory.go +++ b/cache/memory.go @@ -110,25 +110,25 @@ func (bc *MemoryCache) Delete(name string) error { // Incr increase cache counter in memory. // it supports int,int32,int64,uint,uint32,uint64. func (bc *MemoryCache) Incr(key string) error { - bc.RLock() - defer bc.RUnlock() + bc.Lock() + defer bc.Unlock() itm, ok := bc.items[key] if !ok { return errors.New("key not exist") } - switch itm.val.(type) { + switch val := itm.val.(type) { case int: - itm.val = itm.val.(int) + 1 + itm.val = val + 1 case int32: - itm.val = itm.val.(int32) + 1 + itm.val = val + 1 case int64: - itm.val = itm.val.(int64) + 1 + itm.val = val + 1 case uint: - itm.val = itm.val.(uint) + 1 + itm.val = val + 1 case uint32: - itm.val = itm.val.(uint32) + 1 + itm.val = val + 1 case uint64: - itm.val = itm.val.(uint64) + 1 + itm.val = val + 1 default: return errors.New("item val is not (u)int (u)int32 (u)int64") } @@ -137,34 +137,34 @@ func (bc *MemoryCache) Incr(key string) error { // Decr decrease counter in memory. func (bc *MemoryCache) Decr(key string) error { - bc.RLock() - defer bc.RUnlock() + bc.Lock() + defer bc.Unlock() itm, ok := bc.items[key] if !ok { return errors.New("key not exist") } - switch itm.val.(type) { + switch val := itm.val.(type) { case int: - itm.val = itm.val.(int) - 1 + itm.val = val - 1 case int64: - itm.val = itm.val.(int64) - 1 + itm.val = val - 1 case int32: - itm.val = itm.val.(int32) - 1 + itm.val = val - 1 case uint: - if itm.val.(uint) > 0 { - itm.val = itm.val.(uint) - 1 + if val > 0 { + itm.val = val - 1 } else { return errors.New("item val is less than 0") } case uint32: - if itm.val.(uint32) > 0 { - itm.val = itm.val.(uint32) - 1 + if val > 0 { + itm.val = val - 1 } else { return errors.New("item val is less than 0") } case uint64: - if itm.val.(uint64) > 0 { - itm.val = itm.val.(uint64) - 1 + if val > 0 { + itm.val = val - 1 } else { return errors.New("item val is less than 0") } diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index 7bf1335c..5def2da3 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -97,7 +97,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { } } - data, err := goyaml2.Read(bytes.NewBuffer(buf)) + data, err := goyaml2.Read(bytes.NewReader(buf)) if err != nil { log.Println("Goyaml2 ERR>", string(buf), err) return diff --git a/context/input.go b/context/input.go index 81952158..76040616 100644 --- a/context/input.go +++ b/context/input.go @@ -27,6 +27,7 @@ import ( "regexp" "strconv" "strings" + "sync" "github.com/astaxie/beego/session" ) @@ -49,6 +50,7 @@ type BeegoInput struct { pnames []string pvalues []string data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex RequestBody []byte RunMethod string RunController reflect.Type @@ -204,6 +206,7 @@ func (input *BeegoInput) AcceptsXML() bool { func (input *BeegoInput) AcceptsJSON() bool { return acceptsJSONRegex.MatchString(input.Header("Accept")) } + // AcceptsYAML Checks if request accepts json response func (input *BeegoInput) AcceptsYAML() bool { return acceptsYAMLRegex.MatchString(input.Header("Accept")) @@ -377,6 +380,8 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { // Data return the implicit data in the input func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() if input.data == nil { input.data = make(map[interface{}]interface{}) } @@ -385,6 +390,8 @@ func (input *BeegoInput) Data() map[interface{}]interface{} { // GetData returns the stored data in this context. func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() if v, ok := input.data[key]; ok { return v } @@ -394,6 +401,8 @@ func (input *BeegoInput) GetData(key interface{}) interface{} { // SetData stores data with given key in this context. // This data are only available in this context. func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() if input.data == nil { input.data = make(map[interface{}]interface{}) } diff --git a/context/output.go b/context/output.go index 3e277ab2..238dcf45 100644 --- a/context/output.go +++ b/context/output.go @@ -30,7 +30,8 @@ import ( "strconv" "strings" "time" - "gopkg.in/yaml.v2" + + yaml "gopkg.in/yaml.v2" ) // BeegoOutput does work for sending response header. @@ -203,7 +204,6 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) return output.Body(content) } - // YAML writes yaml to response body. func (output *BeegoOutput) YAML(data interface{}) error { output.Header("Content-Type", "application/x-yaml; charset=utf-8") @@ -288,7 +288,20 @@ func (output *BeegoOutput) Download(file string, filename ...string) { } else { fName = filepath.Base(file) } - output.Header("Content-Disposition", "attachment; filename="+url.PathEscape(fName)) + //https://tools.ietf.org/html/rfc6266#section-4.3 + fn := url.PathEscape(fName) + if fName == fn { + fn = "filename=" + fn + } else { + /** + The parameters "filename" and "filename*" differ only in that + "filename*" uses the encoding defined in [RFC5987], allowing the use + of characters not present in the ISO-8859-1 character set + ([ISO-8859-1]). + */ + fn = "filename=" + fName + "; filename*=utf-8''" + fn + } + output.Header("Content-Disposition", "attachment; "+fn) output.Header("Content-Description", "File Transfer") output.Header("Content-Type", "application/octet-stream") output.Header("Content-Transfer-Encoding", "binary") diff --git a/controller.go b/controller.go index 4b8f9807..0e8853b3 100644 --- a/controller.go +++ b/controller.go @@ -17,6 +17,7 @@ package beego import ( "bytes" "errors" + "fmt" "html/template" "io" "mime/multipart" @@ -34,7 +35,7 @@ import ( var ( // ErrAbort custom error when user stop request handler manually. - ErrAbort = errors.New("User stop run") + ErrAbort = errors.New("user stop run") // GlobalControllerRouter store comments with controller. pkgpath+controller:comments GlobalControllerRouter = make(map[string][]ControllerComments) ) @@ -93,7 +94,6 @@ type Controller struct { controllerName string actionName string methodMapping map[string]func() //method:routertree - gotofunc string AppController interface{} // template data @@ -125,6 +125,7 @@ type ControllerInterface interface { Head() Patch() Options() + Trace() Finish() Render() error XSRFToken() string @@ -156,37 +157,59 @@ func (c *Controller) Finish() {} // Get adds a request function to handle GET request. func (c *Controller) Get() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Post adds a request function to handle POST request. func (c *Controller) Post() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Delete adds a request function to handle DELETE request. func (c *Controller) Delete() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Put adds a request function to handle PUT request. func (c *Controller) Put() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Head adds a request function to handle HEAD request. func (c *Controller) Head() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Patch adds a request function to handle PATCH request. func (c *Controller) Patch() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) } // Options adds a request function to handle OPTIONS request. func (c *Controller) Options() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + ts := func(h http.Header) (hs string) { + for k, v := range h { + hs += fmt.Sprintf("\r\n%s: %s", k, v) + } + return + } + hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) + c.Ctx.Output.Header("Content-Type", "message/http") + c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) + c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") + c.Ctx.WriteString(hs) } // HandlerFunc call function with the name @@ -292,7 +315,7 @@ func (c *Controller) viewPath() string { // Redirect sends the redirection response to url with status code. func (c *Controller) Redirect(url string, code int) { - logAccess(c.Ctx, nil, code) + LogAccess(c.Ctx, nil, code) c.Ctx.Redirect(code, url) } diff --git a/error.go b/error.go index 727830df..e5e9fd47 100644 --- a/error.go +++ b/error.go @@ -435,7 +435,7 @@ func exception(errCode string, ctx *context.Context) { func executeError(err *errorInfo, ctx *context.Context, code int) { //make sure to log the error in the access log - logAccess(ctx, nil, code) + LogAccess(ctx, nil, code) if err.errorType == errorTypeHandler { ctx.ResponseWriter.WriteHeader(code) diff --git a/fs.go b/fs.go index bf7002ad..41cc6f6e 100644 --- a/fs.go +++ b/fs.go @@ -42,13 +42,13 @@ func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.Wal } dir, err := fs.Open(path) - defer dir.Close() if err != nil { if err1 := walkFn(path, info, err); err1 != nil { return err1 } return err } + defer dir.Close() dirs, err := dir.Readdir(-1) err1 := walkFn(path, info, err) // If err != nil, walk can't walk into this directory. diff --git a/go.mod b/go.mod index 9b3eb08e..fbdec124 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,9 @@ module github.com/astaxie/beego require ( github.com/Knetic/govaluate v3.0.0+incompatible // indirect + github.com/OwnLocal/goes v1.0.0 github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 - github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 github.com/casbin/casbin v1.7.0 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 diff --git a/go.sum b/go.sum index fbe3a8c3..ab233162 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/OwnLocal/goes v1.0.0/go.mod h1:8rIFjBGTue3lCU0wplczcUgt9Gxgrkkrw7etMIcn8TM= github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M= github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ= github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk= diff --git a/grace/conn.go b/grace/conn.go deleted file mode 100644 index e020f850..00000000 --- a/grace/conn.go +++ /dev/null @@ -1,39 +0,0 @@ -package grace - -import ( - "errors" - "net" - "sync" -) - -type graceConn struct { - net.Conn - server *Server - m sync.Mutex - closed bool -} - -func (c *graceConn) Close() (err error) { - defer func() { - if r := recover(); r != nil { - switch x := r.(type) { - case string: - err = errors.New(x) - case error: - err = x - default: - err = errors.New("Unknown panic") - } - } - }() - - c.m.Lock() - if c.closed { - c.m.Unlock() - return - } - c.server.wg.Done() - c.closed = true - c.m.Unlock() - return c.Conn.Close() -} diff --git a/grace/grace.go b/grace/grace.go index 6ebf8455..fb0cb7bb 100644 --- a/grace/grace.go +++ b/grace/grace.go @@ -78,7 +78,7 @@ var ( DefaultReadTimeOut time.Duration // DefaultWriteTimeOut is the HTTP Write timeout DefaultWriteTimeOut time.Duration - // DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit DefaultMaxHeaderBytes int // DefaultTimeout is the shutdown server's timeout. default is 60s DefaultTimeout = 60 * time.Second @@ -122,7 +122,6 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { } srv = &Server{ - wg: sync.WaitGroup{}, sigChan: make(chan os.Signal), isChild: isChild, SignalHooks: map[int]map[os.Signal][]func(){ @@ -137,20 +136,21 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { syscall.SIGTERM: {}, }, }, - state: StateInit, - Network: "tcp", + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, } - srv.Server = &http.Server{} - srv.Server.Addr = addr - srv.Server.ReadTimeout = DefaultReadTimeOut - srv.Server.WriteTimeout = DefaultWriteTimeOut - srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes - srv.Server.Handler = handler runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv - - return + return srv } // ListenAndServe refer http.ListenAndServe diff --git a/grace/listener.go b/grace/listener.go deleted file mode 100644 index 7ede63a3..00000000 --- a/grace/listener.go +++ /dev/null @@ -1,62 +0,0 @@ -package grace - -import ( - "net" - "os" - "syscall" - "time" -) - -type graceListener struct { - net.Listener - stop chan error - stopped bool - server *Server -} - -func newGraceListener(l net.Listener, srv *Server) (el *graceListener) { - el = &graceListener{ - Listener: l, - stop: make(chan error), - server: srv, - } - go func() { - <-el.stop - el.stopped = true - el.stop <- el.Listener.Close() - }() - return -} - -func (gl *graceListener) Accept() (c net.Conn, err error) { - tc, err := gl.Listener.(*net.TCPListener).AcceptTCP() - if err != nil { - return - } - - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - - c = &graceConn{ - Conn: tc, - server: gl.server, - } - - gl.server.wg.Add(1) - return -} - -func (gl *graceListener) Close() error { - if gl.stopped { - return syscall.EINVAL - } - gl.stop <- nil - return <-gl.stop -} - -func (gl *graceListener) File() *os.File { - // returns a dup(2) - FD_CLOEXEC flag *not* set - tl := gl.Listener.(*net.TCPListener) - fl, _ := tl.File() - return fl -} diff --git a/grace/server.go b/grace/server.go index 513a52a9..1ce8bc78 100644 --- a/grace/server.go +++ b/grace/server.go @@ -1,6 +1,7 @@ package grace import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -12,7 +13,6 @@ import ( "os/exec" "os/signal" "strings" - "sync" "syscall" "time" ) @@ -20,14 +20,13 @@ import ( // Server embedded http.Server type Server struct { *http.Server - GraceListener net.Listener - SignalHooks map[int]map[os.Signal][]func() - tlsInnerListener *graceListener - wg sync.WaitGroup - sigChan chan os.Signal - isChild bool - state uint8 - Network string + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error } // Serve accepts incoming connections on the Listener l, @@ -35,11 +34,19 @@ type Server struct { // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { srv.state = StateRunning - err = srv.Server.Serve(srv.GraceListener) - log.Println(syscall.Getpid(), "Waiting for connections to finish...") - srv.wg.Wait() - srv.state = StateTerminate - return + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + return <-srv.terminalChan } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve @@ -53,14 +60,12 @@ func (srv *Server) ListenAndServe() (err error) { go srv.handleSignals() - l, err := srv.getListener(addr) + srv.ln, err = srv.getListener(addr) if err != nil { log.Println(err) return err } - srv.GraceListener = newGraceListener(l, srv) - if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { @@ -107,14 +112,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { go srv.handleSignals() - l, err := srv.getListener(addr) + ln, err := srv.getListener(addr) if err != nil { log.Println(err) return err } - - srv.tlsInnerListener = newGraceListener(l, srv) - srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) @@ -127,6 +130,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { return err } } + log.Println(os.Getpid(), srv.Addr) return srv.Serve() } @@ -163,14 +167,12 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) log.Println("Mutual HTTPS") go srv.handleSignals() - l, err := srv.getListener(addr) + ln, err := srv.getListener(addr) if err != nil { log.Println(err) return err } - - srv.tlsInnerListener = newGraceListener(l, srv) - srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) if srv.isChild { process, err := os.FindProcess(os.Getppid()) @@ -183,6 +185,7 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) return err } } + log.Println(os.Getpid(), srv.Addr) return srv.Serve() } @@ -213,6 +216,20 @@ func (srv *Server) getListener(laddr string) (l net.Listener, err error) { return } +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + // handleSignals listens for os Signals and calls any hooked in function that the // user had registered with the signal. func (srv *Server) handleSignals() { @@ -265,37 +282,14 @@ func (srv *Server) shutdown() { } srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() if DefaultTimeout >= 0 { - go srv.serverTimeout(DefaultTimeout) - } - err := srv.GraceListener.Close() - if err != nil { - log.Println(syscall.Getpid(), "Listener.Close() error:", err) - } else { - log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.") - } -} - -// serverTimeout forces the server to shutdown in a given timeout - whether it -// finished outstanding requests or not. if Read/WriteTimeout are not set or the -// max header size is very big a connection could hang -func (srv *Server) serverTimeout(d time.Duration) { - defer func() { - if r := recover(); r != nil { - log.Println("WaitGroup at 0", r) - } - }() - if srv.state != StateShuttingDown { - return - } - time.Sleep(d) - log.Println("[STOP - Hammer Time] Forcefully shutting down parent") - for { - if srv.state == StateTerminate { - break - } - srv.wg.Done() + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() } + srv.terminalChan <- srv.Server.Shutdown(ctx) } func (srv *Server) fork() (err error) { @@ -309,12 +303,8 @@ func (srv *Server) fork() (err error) { var files = make([]*os.File, len(runningServers)) var orderArgs = make([]string, len(runningServers)) for _, srvPtr := range runningServers { - switch srvPtr.GraceListener.(type) { - case *graceListener: - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File() - default: - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() - } + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 8970b764..7314ae01 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -206,10 +206,16 @@ func TestToJson(t *testing.T) { t.Fatal(err) } t.Log(ip.Origin) - - if n := strings.Count(ip.Origin, "."); n != 3 { + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { t.Fatal("response is not valid ip") } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + } func TestToFile(t *testing.T) { diff --git a/log.go b/log.go index e9412f92..cc4c0f81 100644 --- a/log.go +++ b/log.go @@ -21,6 +21,7 @@ import ( ) // Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. const ( LevelEmergency = iota LevelAlert @@ -33,75 +34,90 @@ const ( ) // BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. func SetLevel(l int) { logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. func SetLogFuncCall(b bool) { logs.SetLogFuncCall(b) } // SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. func SetLogger(adaptername string, config string) error { return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Emergency(v ...interface{}) { logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Alert(v ...interface{}) { logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Critical(v ...interface{}) { logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Error(v ...interface{}) { logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Warning(v ...interface{}) { logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. func Warn(v ...interface{}) { logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Notice(v ...interface{}) { logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Informational(v ...interface{}) { logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. func Info(v ...interface{}) { logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. func Debug(v ...interface{}) { logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. func Trace(v ...interface{}) { logs.Trace(generateFmtStr(len(v)), v...) } diff --git a/logs/color.go b/logs/color.go deleted file mode 100644 index 41d23638..00000000 --- a/logs/color.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !windows - -package logs - -import "io" - -type ansiColorWriter struct { - w io.Writer - mode outputMode -} - -func (cw *ansiColorWriter) Write(p []byte) (int, error) { - return cw.w.Write(p) -} diff --git a/logs/color_windows.go b/logs/color_windows.go deleted file mode 100644 index 4e28f188..00000000 --- a/logs/color_windows.go +++ /dev/null @@ -1,428 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build windows - -package logs - -import ( - "bytes" - "io" - "strings" - "syscall" - "unsafe" -) - -type ( - csiState int - parseResult int -) - -const ( - outsideCsiCode csiState = iota - firstCsiCode - secondCsiCode -) - -const ( - noConsole parseResult = iota - changedColor - unknown -) - -type ansiColorWriter struct { - w io.Writer - mode outputMode - state csiState - paramStartBuf bytes.Buffer - paramBuf bytes.Buffer -} - -const ( - firstCsiChar byte = '\x1b' - secondeCsiChar byte = '[' - separatorChar byte = ';' - sgrCode byte = 'm' -) - -const ( - foregroundBlue = uint16(0x0001) - foregroundGreen = uint16(0x0002) - foregroundRed = uint16(0x0004) - foregroundIntensity = uint16(0x0008) - backgroundBlue = uint16(0x0010) - backgroundGreen = uint16(0x0020) - backgroundRed = uint16(0x0040) - backgroundIntensity = uint16(0x0080) - underscore = uint16(0x8000) - - foregroundMask = foregroundBlue | foregroundGreen | foregroundRed | foregroundIntensity - backgroundMask = backgroundBlue | backgroundGreen | backgroundRed | backgroundIntensity -) - -const ( - ansiReset = "0" - ansiIntensityOn = "1" - ansiIntensityOff = "21" - ansiUnderlineOn = "4" - ansiUnderlineOff = "24" - ansiBlinkOn = "5" - ansiBlinkOff = "25" - - ansiForegroundBlack = "30" - ansiForegroundRed = "31" - ansiForegroundGreen = "32" - ansiForegroundYellow = "33" - ansiForegroundBlue = "34" - ansiForegroundMagenta = "35" - ansiForegroundCyan = "36" - ansiForegroundWhite = "37" - ansiForegroundDefault = "39" - - ansiBackgroundBlack = "40" - ansiBackgroundRed = "41" - ansiBackgroundGreen = "42" - ansiBackgroundYellow = "43" - ansiBackgroundBlue = "44" - ansiBackgroundMagenta = "45" - ansiBackgroundCyan = "46" - ansiBackgroundWhite = "47" - ansiBackgroundDefault = "49" - - ansiLightForegroundGray = "90" - ansiLightForegroundRed = "91" - ansiLightForegroundGreen = "92" - ansiLightForegroundYellow = "93" - ansiLightForegroundBlue = "94" - ansiLightForegroundMagenta = "95" - ansiLightForegroundCyan = "96" - ansiLightForegroundWhite = "97" - - ansiLightBackgroundGray = "100" - ansiLightBackgroundRed = "101" - ansiLightBackgroundGreen = "102" - ansiLightBackgroundYellow = "103" - ansiLightBackgroundBlue = "104" - ansiLightBackgroundMagenta = "105" - ansiLightBackgroundCyan = "106" - ansiLightBackgroundWhite = "107" -) - -type drawType int - -const ( - foreground drawType = iota - background -) - -type winColor struct { - code uint16 - drawType drawType -} - -var colorMap = map[string]winColor{ - ansiForegroundBlack: {0, foreground}, - ansiForegroundRed: {foregroundRed, foreground}, - ansiForegroundGreen: {foregroundGreen, foreground}, - ansiForegroundYellow: {foregroundRed | foregroundGreen, foreground}, - ansiForegroundBlue: {foregroundBlue, foreground}, - ansiForegroundMagenta: {foregroundRed | foregroundBlue, foreground}, - ansiForegroundCyan: {foregroundGreen | foregroundBlue, foreground}, - ansiForegroundWhite: {foregroundRed | foregroundGreen | foregroundBlue, foreground}, - ansiForegroundDefault: {foregroundRed | foregroundGreen | foregroundBlue, foreground}, - - ansiBackgroundBlack: {0, background}, - ansiBackgroundRed: {backgroundRed, background}, - ansiBackgroundGreen: {backgroundGreen, background}, - ansiBackgroundYellow: {backgroundRed | backgroundGreen, background}, - ansiBackgroundBlue: {backgroundBlue, background}, - ansiBackgroundMagenta: {backgroundRed | backgroundBlue, background}, - ansiBackgroundCyan: {backgroundGreen | backgroundBlue, background}, - ansiBackgroundWhite: {backgroundRed | backgroundGreen | backgroundBlue, background}, - ansiBackgroundDefault: {0, background}, - - ansiLightForegroundGray: {foregroundIntensity, foreground}, - ansiLightForegroundRed: {foregroundIntensity | foregroundRed, foreground}, - ansiLightForegroundGreen: {foregroundIntensity | foregroundGreen, foreground}, - ansiLightForegroundYellow: {foregroundIntensity | foregroundRed | foregroundGreen, foreground}, - ansiLightForegroundBlue: {foregroundIntensity | foregroundBlue, foreground}, - ansiLightForegroundMagenta: {foregroundIntensity | foregroundRed | foregroundBlue, foreground}, - ansiLightForegroundCyan: {foregroundIntensity | foregroundGreen | foregroundBlue, foreground}, - ansiLightForegroundWhite: {foregroundIntensity | foregroundRed | foregroundGreen | foregroundBlue, foreground}, - - ansiLightBackgroundGray: {backgroundIntensity, background}, - ansiLightBackgroundRed: {backgroundIntensity | backgroundRed, background}, - ansiLightBackgroundGreen: {backgroundIntensity | backgroundGreen, background}, - ansiLightBackgroundYellow: {backgroundIntensity | backgroundRed | backgroundGreen, background}, - ansiLightBackgroundBlue: {backgroundIntensity | backgroundBlue, background}, - ansiLightBackgroundMagenta: {backgroundIntensity | backgroundRed | backgroundBlue, background}, - ansiLightBackgroundCyan: {backgroundIntensity | backgroundGreen | backgroundBlue, background}, - ansiLightBackgroundWhite: {backgroundIntensity | backgroundRed | backgroundGreen | backgroundBlue, background}, -} - -var ( - kernel32 = syscall.NewLazyDLL("kernel32.dll") - procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute") - procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") - defaultAttr *textAttributes -) - -func init() { - screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout)) - if screenInfo != nil { - colorMap[ansiForegroundDefault] = winColor{ - screenInfo.WAttributes & (foregroundRed | foregroundGreen | foregroundBlue), - foreground, - } - colorMap[ansiBackgroundDefault] = winColor{ - screenInfo.WAttributes & (backgroundRed | backgroundGreen | backgroundBlue), - background, - } - defaultAttr = convertTextAttr(screenInfo.WAttributes) - } -} - -type coord struct { - X, Y int16 -} - -type smallRect struct { - Left, Top, Right, Bottom int16 -} - -type consoleScreenBufferInfo struct { - DwSize coord - DwCursorPosition coord - WAttributes uint16 - SrWindow smallRect - DwMaximumWindowSize coord -} - -func getConsoleScreenBufferInfo(hConsoleOutput uintptr) *consoleScreenBufferInfo { - var csbi consoleScreenBufferInfo - ret, _, _ := procGetConsoleScreenBufferInfo.Call( - hConsoleOutput, - uintptr(unsafe.Pointer(&csbi))) - if ret == 0 { - return nil - } - return &csbi -} - -func setConsoleTextAttribute(hConsoleOutput uintptr, wAttributes uint16) bool { - ret, _, _ := procSetConsoleTextAttribute.Call( - hConsoleOutput, - uintptr(wAttributes)) - return ret != 0 -} - -type textAttributes struct { - foregroundColor uint16 - backgroundColor uint16 - foregroundIntensity uint16 - backgroundIntensity uint16 - underscore uint16 - otherAttributes uint16 -} - -func convertTextAttr(winAttr uint16) *textAttributes { - fgColor := winAttr & (foregroundRed | foregroundGreen | foregroundBlue) - bgColor := winAttr & (backgroundRed | backgroundGreen | backgroundBlue) - fgIntensity := winAttr & foregroundIntensity - bgIntensity := winAttr & backgroundIntensity - underline := winAttr & underscore - otherAttributes := winAttr &^ (foregroundMask | backgroundMask | underscore) - return &textAttributes{fgColor, bgColor, fgIntensity, bgIntensity, underline, otherAttributes} -} - -func convertWinAttr(textAttr *textAttributes) uint16 { - var winAttr uint16 - winAttr |= textAttr.foregroundColor - winAttr |= textAttr.backgroundColor - winAttr |= textAttr.foregroundIntensity - winAttr |= textAttr.backgroundIntensity - winAttr |= textAttr.underscore - winAttr |= textAttr.otherAttributes - return winAttr -} - -func changeColor(param []byte) parseResult { - screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout)) - if screenInfo == nil { - return noConsole - } - - winAttr := convertTextAttr(screenInfo.WAttributes) - strParam := string(param) - if len(strParam) <= 0 { - strParam = "0" - } - csiParam := strings.Split(strParam, string(separatorChar)) - for _, p := range csiParam { - c, ok := colorMap[p] - switch { - case !ok: - switch p { - case ansiReset: - winAttr.foregroundColor = defaultAttr.foregroundColor - winAttr.backgroundColor = defaultAttr.backgroundColor - winAttr.foregroundIntensity = defaultAttr.foregroundIntensity - winAttr.backgroundIntensity = defaultAttr.backgroundIntensity - winAttr.underscore = 0 - winAttr.otherAttributes = 0 - case ansiIntensityOn: - winAttr.foregroundIntensity = foregroundIntensity - case ansiIntensityOff: - winAttr.foregroundIntensity = 0 - case ansiUnderlineOn: - winAttr.underscore = underscore - case ansiUnderlineOff: - winAttr.underscore = 0 - case ansiBlinkOn: - winAttr.backgroundIntensity = backgroundIntensity - case ansiBlinkOff: - winAttr.backgroundIntensity = 0 - default: - // unknown code - } - case c.drawType == foreground: - winAttr.foregroundColor = c.code - case c.drawType == background: - winAttr.backgroundColor = c.code - } - } - winTextAttribute := convertWinAttr(winAttr) - setConsoleTextAttribute(uintptr(syscall.Stdout), winTextAttribute) - - return changedColor -} - -func parseEscapeSequence(command byte, param []byte) parseResult { - if defaultAttr == nil { - return noConsole - } - - switch command { - case sgrCode: - return changeColor(param) - default: - return unknown - } -} - -func (cw *ansiColorWriter) flushBuffer() (int, error) { - return cw.flushTo(cw.w) -} - -func (cw *ansiColorWriter) resetBuffer() (int, error) { - return cw.flushTo(nil) -} - -func (cw *ansiColorWriter) flushTo(w io.Writer) (int, error) { - var n1, n2 int - var err error - - startBytes := cw.paramStartBuf.Bytes() - cw.paramStartBuf.Reset() - if w != nil { - n1, err = cw.w.Write(startBytes) - if err != nil { - return n1, err - } - } else { - n1 = len(startBytes) - } - paramBytes := cw.paramBuf.Bytes() - cw.paramBuf.Reset() - if w != nil { - n2, err = cw.w.Write(paramBytes) - if err != nil { - return n1 + n2, err - } - } else { - n2 = len(paramBytes) - } - return n1 + n2, nil -} - -func isParameterChar(b byte) bool { - return ('0' <= b && b <= '9') || b == separatorChar -} - -func (cw *ansiColorWriter) Write(p []byte) (int, error) { - var r, nw, first, last int - if cw.mode != DiscardNonColorEscSeq { - cw.state = outsideCsiCode - cw.resetBuffer() - } - - var err error - for i, ch := range p { - switch cw.state { - case outsideCsiCode: - if ch == firstCsiChar { - cw.paramStartBuf.WriteByte(ch) - cw.state = firstCsiCode - } - case firstCsiCode: - switch ch { - case firstCsiChar: - cw.paramStartBuf.WriteByte(ch) - break - case secondeCsiChar: - cw.paramStartBuf.WriteByte(ch) - cw.state = secondCsiCode - last = i - 1 - default: - cw.resetBuffer() - cw.state = outsideCsiCode - } - case secondCsiCode: - if isParameterChar(ch) { - cw.paramBuf.WriteByte(ch) - } else { - nw, err = cw.w.Write(p[first:last]) - r += nw - if err != nil { - return r, err - } - first = i + 1 - result := parseEscapeSequence(ch, cw.paramBuf.Bytes()) - if result == noConsole || (cw.mode == OutputNonColorEscSeq && result == unknown) { - cw.paramBuf.WriteByte(ch) - nw, err := cw.flushBuffer() - if err != nil { - return r, err - } - r += nw - } else { - n, _ := cw.resetBuffer() - // Add one more to the size of the buffer for the last ch - r += n + 1 - } - - cw.state = outsideCsiCode - } - default: - cw.state = outsideCsiCode - } - } - - if cw.mode != DiscardNonColorEscSeq || cw.state == outsideCsiCode { - nw, err = cw.w.Write(p[first:]) - r += nw - } - - return r, err -} diff --git a/logs/color_windows_test.go b/logs/color_windows_test.go deleted file mode 100644 index 5074841a..00000000 --- a/logs/color_windows_test.go +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build windows - -package logs - -import ( - "bytes" - "fmt" - "syscall" - "testing" -) - -var GetConsoleScreenBufferInfo = getConsoleScreenBufferInfo - -func ChangeColor(color uint16) { - setConsoleTextAttribute(uintptr(syscall.Stdout), color) -} - -func ResetColor() { - ChangeColor(uint16(0x0007)) -} - -func TestWritePlanText(t *testing.T) { - inner := bytes.NewBufferString("") - w := NewAnsiColorWriter(inner) - expected := "plain text" - fmt.Fprintf(w, expected) - actual := inner.String() - if actual != expected { - t.Errorf("Get %q, want %q", actual, expected) - } -} - -func TestWriteParseText(t *testing.T) { - inner := bytes.NewBufferString("") - w := NewAnsiColorWriter(inner) - - inputTail := "\x1b[0mtail text" - expectedTail := "tail text" - fmt.Fprintf(w, inputTail) - actualTail := inner.String() - inner.Reset() - if actualTail != expectedTail { - t.Errorf("Get %q, want %q", actualTail, expectedTail) - } - - inputHead := "head text\x1b[0m" - expectedHead := "head text" - fmt.Fprintf(w, inputHead) - actualHead := inner.String() - inner.Reset() - if actualHead != expectedHead { - t.Errorf("Get %q, want %q", actualHead, expectedHead) - } - - inputBothEnds := "both ends \x1b[0m text" - expectedBothEnds := "both ends text" - fmt.Fprintf(w, inputBothEnds) - actualBothEnds := inner.String() - inner.Reset() - if actualBothEnds != expectedBothEnds { - t.Errorf("Get %q, want %q", actualBothEnds, expectedBothEnds) - } - - inputManyEsc := "\x1b\x1b\x1b\x1b[0m many esc" - expectedManyEsc := "\x1b\x1b\x1b many esc" - fmt.Fprintf(w, inputManyEsc) - actualManyEsc := inner.String() - inner.Reset() - if actualManyEsc != expectedManyEsc { - t.Errorf("Get %q, want %q", actualManyEsc, expectedManyEsc) - } - - expectedSplit := "split text" - for _, ch := range "split \x1b[0m text" { - fmt.Fprintf(w, string(ch)) - } - actualSplit := inner.String() - inner.Reset() - if actualSplit != expectedSplit { - t.Errorf("Get %q, want %q", actualSplit, expectedSplit) - } -} - -type screenNotFoundError struct { - error -} - -func writeAnsiColor(expectedText, colorCode string) (actualText string, actualAttributes uint16, err error) { - inner := bytes.NewBufferString("") - w := NewAnsiColorWriter(inner) - fmt.Fprintf(w, "\x1b[%sm%s", colorCode, expectedText) - - actualText = inner.String() - screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout)) - if screenInfo != nil { - actualAttributes = screenInfo.WAttributes - } else { - err = &screenNotFoundError{} - } - return -} - -type testParam struct { - text string - attributes uint16 - ansiColor string -} - -func TestWriteAnsiColorText(t *testing.T) { - screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout)) - if screenInfo == nil { - t.Fatal("Could not get ConsoleScreenBufferInfo") - } - defer ChangeColor(screenInfo.WAttributes) - defaultFgColor := screenInfo.WAttributes & uint16(0x0007) - defaultBgColor := screenInfo.WAttributes & uint16(0x0070) - defaultFgIntensity := screenInfo.WAttributes & uint16(0x0008) - defaultBgIntensity := screenInfo.WAttributes & uint16(0x0080) - - fgParam := []testParam{ - {"foreground black ", uint16(0x0000 | 0x0000), "30"}, - {"foreground red ", uint16(0x0004 | 0x0000), "31"}, - {"foreground green ", uint16(0x0002 | 0x0000), "32"}, - {"foreground yellow ", uint16(0x0006 | 0x0000), "33"}, - {"foreground blue ", uint16(0x0001 | 0x0000), "34"}, - {"foreground magenta", uint16(0x0005 | 0x0000), "35"}, - {"foreground cyan ", uint16(0x0003 | 0x0000), "36"}, - {"foreground white ", uint16(0x0007 | 0x0000), "37"}, - {"foreground default", defaultFgColor | 0x0000, "39"}, - {"foreground light gray ", uint16(0x0000 | 0x0008 | 0x0000), "90"}, - {"foreground light red ", uint16(0x0004 | 0x0008 | 0x0000), "91"}, - {"foreground light green ", uint16(0x0002 | 0x0008 | 0x0000), "92"}, - {"foreground light yellow ", uint16(0x0006 | 0x0008 | 0x0000), "93"}, - {"foreground light blue ", uint16(0x0001 | 0x0008 | 0x0000), "94"}, - {"foreground light magenta", uint16(0x0005 | 0x0008 | 0x0000), "95"}, - {"foreground light cyan ", uint16(0x0003 | 0x0008 | 0x0000), "96"}, - {"foreground light white ", uint16(0x0007 | 0x0008 | 0x0000), "97"}, - } - - bgParam := []testParam{ - {"background black ", uint16(0x0007 | 0x0000), "40"}, - {"background red ", uint16(0x0007 | 0x0040), "41"}, - {"background green ", uint16(0x0007 | 0x0020), "42"}, - {"background yellow ", uint16(0x0007 | 0x0060), "43"}, - {"background blue ", uint16(0x0007 | 0x0010), "44"}, - {"background magenta", uint16(0x0007 | 0x0050), "45"}, - {"background cyan ", uint16(0x0007 | 0x0030), "46"}, - {"background white ", uint16(0x0007 | 0x0070), "47"}, - {"background default", uint16(0x0007) | defaultBgColor, "49"}, - {"background light gray ", uint16(0x0007 | 0x0000 | 0x0080), "100"}, - {"background light red ", uint16(0x0007 | 0x0040 | 0x0080), "101"}, - {"background light green ", uint16(0x0007 | 0x0020 | 0x0080), "102"}, - {"background light yellow ", uint16(0x0007 | 0x0060 | 0x0080), "103"}, - {"background light blue ", uint16(0x0007 | 0x0010 | 0x0080), "104"}, - {"background light magenta", uint16(0x0007 | 0x0050 | 0x0080), "105"}, - {"background light cyan ", uint16(0x0007 | 0x0030 | 0x0080), "106"}, - {"background light white ", uint16(0x0007 | 0x0070 | 0x0080), "107"}, - } - - resetParam := []testParam{ - {"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, "0"}, - {"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, ""}, - } - - boldParam := []testParam{ - {"bold on", uint16(0x0007 | 0x0008), "1"}, - {"bold off", uint16(0x0007), "21"}, - } - - underscoreParam := []testParam{ - {"underscore on", uint16(0x0007 | 0x8000), "4"}, - {"underscore off", uint16(0x0007), "24"}, - } - - blinkParam := []testParam{ - {"blink on", uint16(0x0007 | 0x0080), "5"}, - {"blink off", uint16(0x0007), "25"}, - } - - mixedParam := []testParam{ - {"both black, bold, underline, blink", uint16(0x0000 | 0x0000 | 0x0008 | 0x8000 | 0x0080), "30;40;1;4;5"}, - {"both red, bold, underline, blink", uint16(0x0004 | 0x0040 | 0x0008 | 0x8000 | 0x0080), "31;41;1;4;5"}, - {"both green, bold, underline, blink", uint16(0x0002 | 0x0020 | 0x0008 | 0x8000 | 0x0080), "32;42;1;4;5"}, - {"both yellow, bold, underline, blink", uint16(0x0006 | 0x0060 | 0x0008 | 0x8000 | 0x0080), "33;43;1;4;5"}, - {"both blue, bold, underline, blink", uint16(0x0001 | 0x0010 | 0x0008 | 0x8000 | 0x0080), "34;44;1;4;5"}, - {"both magenta, bold, underline, blink", uint16(0x0005 | 0x0050 | 0x0008 | 0x8000 | 0x0080), "35;45;1;4;5"}, - {"both cyan, bold, underline, blink", uint16(0x0003 | 0x0030 | 0x0008 | 0x8000 | 0x0080), "36;46;1;4;5"}, - {"both white, bold, underline, blink", uint16(0x0007 | 0x0070 | 0x0008 | 0x8000 | 0x0080), "37;47;1;4;5"}, - {"both default, bold, underline, blink", uint16(defaultFgColor | defaultBgColor | 0x0008 | 0x8000 | 0x0080), "39;49;1;4;5"}, - } - - assertTextAttribute := func(expectedText string, expectedAttributes uint16, ansiColor string) { - actualText, actualAttributes, err := writeAnsiColor(expectedText, ansiColor) - if actualText != expectedText { - t.Errorf("Get %q, want %q", actualText, expectedText) - } - if err != nil { - t.Fatal("Could not get ConsoleScreenBufferInfo") - } - if actualAttributes != expectedAttributes { - t.Errorf("Text: %q, Get 0x%04x, want 0x%04x", expectedText, actualAttributes, expectedAttributes) - } - } - - for _, v := range fgParam { - ResetColor() - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - for _, v := range bgParam { - ChangeColor(uint16(0x0070 | 0x0007)) - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - for _, v := range resetParam { - ChangeColor(uint16(0x0000 | 0x0070 | 0x0008)) - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - ResetColor() - for _, v := range boldParam { - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - ResetColor() - for _, v := range underscoreParam { - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - ResetColor() - for _, v := range blinkParam { - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } - - for _, v := range mixedParam { - ResetColor() - assertTextAttribute(v.text, v.attributes, v.ansiColor) - } -} - -func TestIgnoreUnknownSequences(t *testing.T) { - inner := bytes.NewBufferString("") - w := NewModeAnsiColorWriter(inner, OutputNonColorEscSeq) - - inputText := "\x1b[=decpath mode" - expectedTail := inputText - fmt.Fprintf(w, inputText) - actualTail := inner.String() - inner.Reset() - if actualTail != expectedTail { - t.Errorf("Get %q, want %q", actualTail, expectedTail) - } - - inputText = "\x1b[=tailing esc and bracket\x1b[" - expectedTail = inputText - fmt.Fprintf(w, inputText) - actualTail = inner.String() - inner.Reset() - if actualTail != expectedTail { - t.Errorf("Get %q, want %q", actualTail, expectedTail) - } - - inputText = "\x1b[?tailing esc\x1b" - expectedTail = inputText - fmt.Fprintf(w, inputText) - actualTail = inner.String() - inner.Reset() - if actualTail != expectedTail { - t.Errorf("Get %q, want %q", actualTail, expectedTail) - } - - inputText = "\x1b[1h;3punended color code invalid\x1b3" - expectedTail = inputText - fmt.Fprintf(w, inputText) - actualTail = inner.String() - inner.Reset() - if actualTail != expectedTail { - t.Errorf("Get %q, want %q", actualTail, expectedTail) - } -} diff --git a/logs/conn.go b/logs/conn.go index 6d5bf6bf..afe0cbb7 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -63,7 +63,7 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { defer c.innerWriter.Close() } - c.lg.println(when, msg) + c.lg.writeln(when, msg) return nil } diff --git a/logs/console.go b/logs/console.go index e75f2a1b..3dcaee1d 100644 --- a/logs/console.go +++ b/logs/console.go @@ -17,8 +17,10 @@ package logs import ( "encoding/json" "os" - "runtime" + "strings" "time" + + "github.com/shiena/ansicolor" ) // brush is a color join function @@ -54,9 +56,9 @@ type consoleWriter struct { // NewConsole create ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { cw := &consoleWriter{ - lg: newLogWriter(os.Stdout), + lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), Level: LevelDebug, - Colorful: runtime.GOOS != "windows", + Colorful: true, } return cw } @@ -67,11 +69,7 @@ func (c *consoleWriter) Init(jsonConfig string) error { if len(jsonConfig) == 0 { return nil } - err := json.Unmarshal([]byte(jsonConfig), c) - if runtime.GOOS == "windows" { - c.Colorful = false - } - return err + return json.Unmarshal([]byte(jsonConfig), c) } // WriteMsg write message in console. @@ -80,9 +78,9 @@ func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { return nil } if c.Colorful { - msg = colors[level](msg) + msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) } - c.lg.println(when, msg) + c.lg.writeln(when, msg) return nil } diff --git a/logs/es/es.go b/logs/es/es.go index 22f4f650..9d6a615c 100644 --- a/logs/es/es.go +++ b/logs/es/es.go @@ -8,8 +8,8 @@ import ( "net/url" "time" + "github.com/OwnLocal/goes" "github.com/astaxie/beego/logs" - "github.com/belogik/goes" ) // NewES return a LoggerInterface @@ -21,7 +21,7 @@ func NewES() logs.Logger { } type esLogger struct { - *goes.Connection + *goes.Client DSN string `json:"dsn"` Level int `json:"level"` } @@ -41,8 +41,8 @@ func (el *esLogger) Init(jsonconfig string) error { } else if host, port, err := net.SplitHostPort(u.Host); err != nil { return err } else { - conn := goes.NewConnection(host, port) - el.Connection = conn + conn := goes.NewClient(host, port) + el.Client = conn } return nil } @@ -78,3 +78,4 @@ func (el *esLogger) Flush() { func init() { logs.Register(logs.AdapterEs, NewES) } + diff --git a/logs/log.go b/logs/log.go index a3614165..49f3794f 100644 --- a/logs/log.go +++ b/logs/log.go @@ -47,7 +47,7 @@ import ( // RFC5424 log message levels. const ( - LevelEmergency = iota + LevelEmergency = iota LevelAlert LevelCritical LevelError @@ -92,7 +92,7 @@ type Logger interface { } var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M] ", "[A] ", "[C] ", "[E] ", "[W] ", "[N] ", "[I] ", "[D] "} +var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, @@ -187,12 +187,12 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } } - log, ok := adapters[adapterName] + logAdapter, ok := adapters[adapterName] if !ok { return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } - lg := log() + lg := logAdapter() err := lg.Init(config) if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) @@ -248,7 +248,7 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { } // writeMsg will always add a '\n' character if p[len(p)-1] == '\n' { - p = p[0: len(p)-1] + p = p[0 : len(p)-1] } // set levelLoggerImpl to ensure all log message will be write out err = bl.writeMsg(levelLoggerImpl, string(p)) @@ -287,7 +287,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error // set to emergency to ensure all log will be print out correctly logLevel = LevelEmergency } else { - msg = levelPrefix[logLevel] + msg + msg = levelPrefix[logLevel] + " " + msg } if bl.asynchronous { diff --git a/logs/logger.go b/logs/logger.go index 428d3aa0..c7cf8a56 100644 --- a/logs/logger.go +++ b/logs/logger.go @@ -15,9 +15,8 @@ package logs import ( - "fmt" "io" - "os" + "runtime" "sync" "time" ) @@ -31,47 +30,13 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) println(when time.Time, msg string) { +func (lg *logWriter) writeln(when time.Time, msg string) { lg.Lock() - h, _, _:= formatTimeHeader(when) + h, _, _ := formatTimeHeader(when) lg.writer.Write(append(append(h, msg...), '\n')) lg.Unlock() } -type outputMode int - -// DiscardNonColorEscSeq supports the divided color escape sequence. -// But non-color escape sequence is not output. -// Please use the OutputNonColorEscSeq If you want to output a non-color -// escape sequences such as ncurses. However, it does not support the divided -// color escape sequence. -const ( - _ outputMode = iota - DiscardNonColorEscSeq - OutputNonColorEscSeq -) - -// NewAnsiColorWriter creates and initializes a new ansiColorWriter -// using io.Writer w as its initial contents. -// In the console of Windows, which change the foreground and background -// colors of the text by the escape sequence. -// In the console of other systems, which writes to w all text. -func NewAnsiColorWriter(w io.Writer) io.Writer { - return NewModeAnsiColorWriter(w, DiscardNonColorEscSeq) -} - -// NewModeAnsiColorWriter create and initializes a new ansiColorWriter -// by specifying the outputMode. -func NewModeAnsiColorWriter(w io.Writer, mode outputMode) io.Writer { - if _, ok := w.(*ansiColorWriter); !ok { - return &ansiColorWriter{ - w: w, - mode: mode, - } - } - return w -} - const ( y1 = `0123456789` y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` @@ -146,63 +111,65 @@ var ( reset = string([]byte{27, 91, 48, 109}) ) +var once sync.Once +var colorMap map[string]string + +func initColor() { + if runtime.GOOS == "windows" { + green = w32Green + white = w32White + yellow = w32Yellow + red = w32Red + blue = w32Blue + magenta = w32Magenta + cyan = w32Cyan + } + colorMap = map[string]string{ + //by color + "green": green, + "white": white, + "yellow": yellow, + "red": red, + //by method + "GET": blue, + "POST": cyan, + "PUT": yellow, + "DELETE": red, + "PATCH": green, + "HEAD": magenta, + "OPTIONS": white, + } +} + // ColorByStatus return color by http code // 2xx return Green // 3xx return White // 4xx return Yellow // 5xx return Red -func ColorByStatus(cond bool, code int) string { +func ColorByStatus(code int) string { + once.Do(initColor) switch { case code >= 200 && code < 300: - return map[bool]string{true: green, false: w32Green}[cond] + return colorMap["green"] case code >= 300 && code < 400: - return map[bool]string{true: white, false: w32White}[cond] + return colorMap["white"] case code >= 400 && code < 500: - return map[bool]string{true: yellow, false: w32Yellow}[cond] + return colorMap["yellow"] default: - return map[bool]string{true: red, false: w32Red}[cond] + return colorMap["red"] } } // ColorByMethod return color by http code -// GET return Blue -// POST return Cyan -// PUT return Yellow -// DELETE return Red -// PATCH return Green -// HEAD return Magenta -// OPTIONS return WHITE -func ColorByMethod(cond bool, method string) string { - switch method { - case "GET": - return map[bool]string{true: blue, false: w32Blue}[cond] - case "POST": - return map[bool]string{true: cyan, false: w32Cyan}[cond] - case "PUT": - return map[bool]string{true: yellow, false: w32Yellow}[cond] - case "DELETE": - return map[bool]string{true: red, false: w32Red}[cond] - case "PATCH": - return map[bool]string{true: green, false: w32Green}[cond] - case "HEAD": - return map[bool]string{true: magenta, false: w32Magenta}[cond] - case "OPTIONS": - return map[bool]string{true: white, false: w32White}[cond] - default: - return reset +func ColorByMethod(method string) string { + once.Do(initColor) + if c := colorMap[method]; c != "" { + return c } + return reset } -// Guard Mutex to guarantee atomic of W32Debug(string) function -var mu sync.Mutex - -// W32Debug Helper method to output colored logs in Windows terminals -func W32Debug(msg string) { - mu.Lock() - defer mu.Unlock() - - current := time.Now() - w := NewAnsiColorWriter(os.Stdout) - - fmt.Fprintf(w, "[beego] %v %s\n", current.Format("2006/01/02 - 15:04:05"), msg) +// ResetColor return reset color +func ResetColor() string { + return reset } diff --git a/logs/logger_test.go b/logs/logger_test.go index 78c67737..15be500d 100644 --- a/logs/logger_test.go +++ b/logs/logger_test.go @@ -15,7 +15,6 @@ package logs import ( - "bytes" "testing" "time" ) @@ -56,20 +55,3 @@ func TestFormatHeader_1(t *testing.T) { tm = tm.Add(dur) } } - -func TestNewAnsiColor1(t *testing.T) { - inner := bytes.NewBufferString("") - w := NewAnsiColorWriter(inner) - if w == inner { - t.Errorf("Get %#v, want %#v", w, inner) - } -} - -func TestNewAnsiColor2(t *testing.T) { - inner := bytes.NewBufferString("") - w1 := NewAnsiColorWriter(inner) - w2 := NewAnsiColorWriter(w1) - if w1 != w2 { - t.Errorf("Get %#v, want %#v", w1, w2) - } -} diff --git a/migration/ddl.go b/migration/ddl.go index 9313acf8..cd2c1c49 100644 --- a/migration/ddl.go +++ b/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" ) // Index struct defines the structure of Index Columns @@ -316,7 +316,7 @@ func (m *Migration) GetSQL() (sql string) { sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) for index, column := range m.Columns { if !column.remove { - beego.BeeLogger.Info("col") + logs.Info("col") sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) } else { sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) diff --git a/migration/migration.go b/migration/migration.go index 97e10c2e..5ddfd972 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -176,8 +176,9 @@ func Register(name string, m Migrationer) error { func Upgrade(lasttime int64) error { sm := sortMap(migrationMap) i := 0 + migs, _ := getAllMigrations() for _, v := range sm { - if v.created > lasttime { + if _, ok := migs[v.name]; !ok { logs.Info("start upgrade", v.name) v.m.Reset() v.m.Up() @@ -310,3 +311,20 @@ func isRollBack(name string) bool { } return false } +func getAllMigrations() (map[string]string, error) { + o := orm.NewOrm() + var maps []orm.Params + migs := make(map[string]string) + num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return migs, err + } + if num > 0 { + for _, v := range maps { + name := v["name"].(string) + migs[name] = v["status"].(string) + } + } + return migs, nil +} diff --git a/namespace.go b/namespace.go index 72f22a72..4952c9d5 100644 --- a/namespace.go +++ b/namespace.go @@ -207,11 +207,11 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { for _, ni := range ns { for k, v := range ni.handlers.routers { - if t, ok := n.handlers.routers[k]; ok { + if _, ok := n.handlers.routers[k]; ok { addPrefix(v, ni.prefix) n.handlers.routers[k].AddTree(ni.prefix, v) } else { - t = NewTree() + t := NewTree() t.AddTree(ni.prefix, v) addPrefix(t, ni.prefix) n.handlers.routers[k] = t @@ -236,11 +236,11 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { func AddNamespace(nl ...*Namespace) { for _, n := range nl { for k, v := range n.handlers.routers { - if t, ok := BeeApp.Handlers.routers[k]; ok { + if _, ok := BeeApp.Handlers.routers[k]; ok { addPrefix(v, n.prefix) BeeApp.Handlers.routers[k].AddTree(n.prefix, v) } else { - t = NewTree() + t := NewTree() t.AddTree(n.prefix, v) addPrefix(t, n.prefix) BeeApp.Handlers.routers[k] = t diff --git a/orm/db.go b/orm/db.go index dfaa5f1d..2148daaa 100644 --- a/orm/db.go +++ b/orm/db.go @@ -621,6 +621,31 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, err } + var findAutoNowAdd, findAutoNow bool + var index int + for i, col := range setNames { + if mi.fields.GetByColumn(col).autoNowAdd { + index = i + findAutoNowAdd = true + } + if mi.fields.GetByColumn(col).autoNow { + findAutoNow = true + } + } + if findAutoNowAdd { + setNames = append(setNames[0:index], setNames[index+1:]...) + setValues = append(setValues[0:index], setValues[index+1:]...) + } + + if !findAutoNow { + for col, info := range mi.fields.columns { + if info.autoNow { + setNames = append(setNames, col) + setValues = append(setValues, time.Now()) + } + } + } + setValues = append(setValues, pkValue) Q := d.ins.TableQuote() diff --git a/orm/db_alias.go b/orm/db_alias.go index a43e70e3..51ce10f3 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -103,6 +104,96 @@ func (ac *_dbCache) getDefault() (al *alias) { return } +type DB struct { + *sync.RWMutex + DB *sql.DB + stmts map[string]*sql.Stmt +} + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +func (d *DB) getStmt(query string) (*sql.Stmt, error) { + d.RLock() + if stmt, ok := d.stmts[query]; ok { + d.RUnlock() + return stmt, nil + } + d.RUnlock() + + stmt, err := d.Prepare(query) + if err != nil { + return nil, err + } + d.Lock() + d.stmts[query] = stmt + d.Unlock() + return stmt, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Exec(args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.Query(args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := d.getStmt(query) + if err != nil { + return nil, err + } + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + } + return stmt.QueryRow(args...) + +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + + stmt, err := d.getStmt(query) + if err != nil { + panic(err) + } + return stmt.QueryRowContext(ctx, args) +} + type alias struct { Name string Driver DriverType @@ -110,7 +201,7 @@ type alias struct { DataSource string MaxIdleConns int MaxOpenConns int - DB *sql.DB + DB *DB DbBaser dbBaser TZ *time.Location Engine string @@ -176,7 +267,11 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), + } if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -272,7 +367,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { func SetMaxIdleConns(aliasName string, maxIdleConns int) { al := getDbAlias(aliasName) al.MaxIdleConns = maxIdleConns - al.DB.SetMaxIdleConns(maxIdleConns) + al.DB.DB.SetMaxIdleConns(maxIdleConns) } // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name @@ -296,7 +391,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } al, ok := dataBaseCache.get(name) if ok { - return al.DB, nil + return al.DB.DB, nil } return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) } diff --git a/orm/models_boot.go b/orm/models_boot.go index badfd11b..456e5896 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -335,11 +335,11 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) { // BootStrap bootstrap models. // make all model parsed and can not add more models func BootStrap() { + modelCache.Lock() + defer modelCache.Unlock() if modelCache.done { return } - modelCache.Lock() - defer modelCache.Unlock() bootStrap() modelCache.done = true } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 479f5ae6..7044b0bd 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -301,7 +301,7 @@ checkType: fi.sf = sf fi.fullName = mi.fullName + mName + "." + sf.Name - fi.description = sf.Tag.Get("description") + fi.description = tags["description"] fi.null = attrs["null"] fi.index = attrs["index"] fi.auto = attrs["auto"] diff --git a/orm/models_utils.go b/orm/models_utils.go index 31f8fb5a..71127a6b 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -44,6 +44,7 @@ var supportTag = map[string]int{ "decimals": 2, "on_delete": 2, "type": 2, + "description": 2, } // get reflect.Type name with package path. @@ -65,7 +66,7 @@ func getTableName(val reflect.Value) string { return snakeString(reflect.Indirect(val).Type().Name()) } -// get table engine, mysiam or innodb. +// get table engine, myisam or innodb. func getTableEngine(val reflect.Value) string { fun := val.MethodByName("TableEngine") if fun.IsValid() { diff --git a/orm/orm.go b/orm/orm.go index bcf6e4be..11e38fd9 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -60,6 +60,7 @@ import ( "fmt" "os" "reflect" + "sync" "time" ) @@ -72,7 +73,7 @@ const ( var ( Debug = false DebugLog = NewLog(os.Stdout) - DefaultRowsLimit = 1000 + DefaultRowsLimit = -1 DefaultRelsDepth = 2 DefaultTimeLoc = time.Local ErrTxHasBegan = errors.New(" transaction already begin") @@ -522,6 +523,15 @@ func (o *orm) Driver() Driver { return driver(o.alias.Name) } +// return sql.DBStats for current database +func (o *orm) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.DB.Stats() + return &stats + } + return nil +} + // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once @@ -548,7 +558,11 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { al.Name = aliasName al.DriverName = driverName - al.DB = db + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmts: make(map[string]*sql.Stmt), + } detectTZ(al) diff --git a/orm/orm_log.go b/orm/orm_log.go index 2a879c13..f107bb59 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -29,6 +29,9 @@ type Log struct { *log.Logger } +//costomer log func +var LogFunc func(query map[string]interface{}) + // NewLog set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) @@ -37,12 +40,15 @@ func NewLog(out io.Writer) *Log { } func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { + var logMap = make(map[string]interface{}) sub := time.Now().Sub(t) / 1e5 elsp := float64(int(sub)) / 10.0 + logMap["cost_time"] = elsp flag := " OK" if err != nil { flag = "FAIL" } + logMap["flag"] = flag con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) cons := make([]string, 0, len(args)) for _, arg := range args { @@ -54,6 +60,10 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error if err != nil { con += " - " + err.Error() } + logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) + if LogFunc != nil{ + LogFunc(logMap) + } DebugLog.Println(con) } diff --git a/orm/orm_raw.go b/orm/orm_raw.go index c8ef4398..3325a7ea 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Struct: if value == nil { ind.Set(reflect.Zero(ind.Type())) - - } else if _, ok := ind.Interface().(time.Time); ok { + return + } + switch ind.Interface().(type) { + case time.Time: var str string switch d := value.(type) { case time.Time: @@ -178,7 +180,25 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } } + case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: + indi := reflect.New(ind.Type()).Interface() + sc, ok := indi.(sql.Scanner) + if !ok { + return + } + err := sc.Scan(value) + if err == nil { + ind.Set(reflect.Indirect(reflect.ValueOf(sc))) + } } + + case reflect.Ptr: + if value == nil { + ind.Set(reflect.Zero(ind.Type())) + break + } + ind.Set(reflect.New(ind.Type().Elem())) + o.setFieldValue(reflect.Indirect(ind), value) } } diff --git a/orm/orm_test.go b/orm/orm_test.go index 89a714b6..bdb430b6 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -458,6 +458,15 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) } func TestDataCustomTypes(t *testing.T) { @@ -1679,6 +1688,31 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(uid, 4)) throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) } // user_profile table @@ -1771,6 +1805,32 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].Age, 30)) + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) } func TestRawValues(t *testing.T) { diff --git a/orm/types.go b/orm/types.go index ddf39a2b..2fd10774 100644 --- a/orm/types.go +++ b/orm/types.go @@ -55,7 +55,7 @@ type Ormer interface { // for example: // user := new(User) // id, err = Ormer.Insert(user) - // user must a pointer and Insert will set user's pk field + // user must be a pointer and Insert will set user's pk field Insert(interface{}) (int64, error) // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") // if colu type is integer : can use(+-*/), string : convert(colu,"value") @@ -128,6 +128,7 @@ type Ormer interface { // // update user testing's name to slene Raw(query string, args ...interface{}) RawSeter Driver() Driver + DBStats() *sql.DBStats } // Inserter insert prepared statement diff --git a/parser.go b/parser.go index a8690274..5e6b9111 100644 --- a/parser.go +++ b/parser.go @@ -35,7 +35,7 @@ import ( "github.com/astaxie/beego/utils" ) -var globalRouterTemplate = `package routers +var globalRouterTemplate = `package {{.routersDir}} import ( "github.com/astaxie/beego" @@ -459,13 +459,17 @@ func genRouterCode(pkgRealpath string) { imports := "" if len(c.ImportComments) > 0 { for _, i := range c.ImportComments { + var s string if i.ImportAlias != "" { - imports += fmt.Sprintf(` + s = fmt.Sprintf(` %s "%s"`, i.ImportAlias, i.ImportPath) } else { - imports += fmt.Sprintf(` + s = fmt.Sprintf(` "%s"`, i.ImportPath) } + if !strings.Contains(globalimport, s) { + imports += s + } } } @@ -490,7 +494,7 @@ func genRouterCode(pkgRealpath string) { }`, filters) } - globalimport = imports + globalimport += imports globalinfo = globalinfo + ` beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], @@ -512,7 +516,9 @@ func genRouterCode(pkgRealpath string) { } defer f.Close() + routersDir := AppConfig.DefaultString("routersdir", "routers") content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) f.WriteString(content) } @@ -570,7 +576,8 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { func getRouterDir(pkgRealpath string) string { dir := filepath.Dir(pkgRealpath) for { - d := filepath.Join(dir, "routers") + routersDir := AppConfig.DefaultString("routersdir", "routers") + d := filepath.Join(dir, routersDir) if utils.FileExists(d) { return d } diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go index f816029c..10e25f3f 100644 --- a/plugins/apiauth/apiauth.go +++ b/plugins/apiauth/apiauth.go @@ -72,8 +72,8 @@ import ( // AppIDToAppSecret is used to get appsecret throw appid type AppIDToAppSecret func(string) string -// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret -func APIBaiscAuth(appid, appkey string) beego.FilterFunc { +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { ft := func(aid string) string { if aid == appid { return appkey @@ -83,6 +83,11 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc { return APISecretAuth(ft, 300) } +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + // APISecretAuth use AppIdToAppSecret verify and func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { return func(ctx *context.Context) { diff --git a/router.go b/router.go index 997b6854..3593be4c 100644 --- a/router.go +++ b/router.go @@ -15,12 +15,12 @@ package beego import ( + "errors" "fmt" "net/http" "path" "path/filepath" "reflect" - "runtime" "strconv" "strings" "sync" @@ -479,8 +479,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // add Filter into func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { if pos < BeforeStatic || pos > FinishRouter { - err = fmt.Errorf("can not find your filter position") - return + return errors.New("can not find your filter position") } p.enableFilter = true p.filters[pos] = append(p.filters[pos], mr) @@ -510,10 +509,10 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri } } } - controllName := strings.Join(paths[:len(paths)-1], "/") + controllerName := strings.Join(paths[:len(paths)-1], "/") methodName := paths[len(paths)-1] for m, t := range p.routers { - ok, url := p.geturl(t, "/", controllName, methodName, params, m) + ok, url := p.getURL(t, "/", controllerName, methodName, params, m) if ok { return url } @@ -521,17 +520,17 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri return "" } -func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { +func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { for _, subtree := range t.fixrouters { u := path.Join(url, subtree.prefix) - ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod) + ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) if ok { return ok, u } } if t.wildcard != nil { u := path.Join(url, urlPlaceholder) - ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod) + ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) if ok { return ok, u } @@ -539,7 +538,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin for _, l := range t.leaves { if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { find := false if HTTPMETHOD[strings.ToUpper(methodName)] { if len(c.methods) == 0 { @@ -578,18 +577,18 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin } } } - canskip := false + canSkip := false for _, v := range l.wildcards { if v == ":" { - canskip = true + canSkip = true continue } if u, ok := params[v]; ok { delete(params, v) url = strings.Replace(url, urlPlaceholder, u, 1) } else { - if canskip { - canskip = false + if canSkip { + canSkip = false continue } return false, "" @@ -598,27 +597,27 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin return true, url + toURL(params) } var i int - var startreg bool - regurl := "" + var startReg bool + regURL := "" for _, v := range strings.Trim(l.regexps.String(), "^$") { if v == '(' { - startreg = true + startReg = true continue } else if v == ')' { - startreg = false + startReg = false if v, ok := params[l.wildcards[i]]; ok { delete(params, l.wildcards[i]) - regurl = regurl + v + regURL = regURL + v i++ } else { break } - } else if !startreg { - regurl = string(append([]rune(regurl), v)) + } else if !startReg { + regURL = string(append([]rune(regURL), v)) } } - if l.regexps.MatchString(regurl) { - ps := strings.Split(regurl, "/") + if l.regexps.MatchString(regURL) { + ps := strings.Split(regURL, "/") for _, p := range ps { url = strings.Replace(url, urlPlaceholder, p, 1) } @@ -690,7 +689,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) // filter wrong http method if !HTTPMETHOD[r.Method] { - http.Error(rw, "Method Not Allowed", 405) + exception("405", context) goto Admin } @@ -779,7 +778,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) runRouter = routerInfo.controllerType methodParams = routerInfo.methodParams method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { method = http.MethodPut } if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { @@ -844,6 +843,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) execController.Patch() case http.MethodOptions: execController.Options() + case http.MethodTrace: + execController.Trace() default: if !execController.HandlerFunc(runMethod) { vc := reflect.ValueOf(execController) @@ -889,7 +890,7 @@ Admin: statusCode = 200 } - logAccess(context, &startTime, statusCode) + LogAccess(context, &startTime, statusCode) timeDur := time.Since(startTime) context.ResponseWriter.Elapsed = timeDur @@ -900,38 +901,28 @@ Admin: } if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { + routerName := "" if runRouter != nil { - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) - } else { - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur) + routerName = runRouter.Name() } + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) } } if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { - var devInfo string - iswin := (runtime.GOOS == "windows") - statusColor := logs.ColorByStatus(iswin, statusCode) - methodColor := logs.ColorByMethod(iswin, r.Method) - resetColor := logs.ColorByMethod(iswin, "") - if findRouter { - if routerInfo != nil { - devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode, - resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path, - routerInfo.pattern) - } else { - devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, - timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path) - } - } else { - devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, - timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path) - } - if iswin { - logs.W32Debug(devInfo) - } else { - logs.Debug(devInfo) + match := map[bool]string{true: "match", false: "nomatch"} + devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", + context.Input.IP(), + logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), + timeDur.String(), + match[findRouter], + logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), + r.URL.Path) + if routerInfo != nil { + devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) } + + logs.Debug(devInfo) } // Call WriteHeader if status code has been set changed if context.Output.Status != 0 { @@ -980,7 +971,8 @@ func toURL(params map[string]string) string { return strings.TrimRight(u, "&") } -func logAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { //Skip logging if AccessLogs config is false if !BConfig.Log.AccessLogs { return diff --git a/router_test.go b/router_test.go index 90104427..2797b33a 100644 --- a/router_test.go +++ b/router_test.go @@ -71,10 +71,6 @@ func (tc *TestController) GetEmptyBody() { tc.Ctx.Output.Body(res) } -type ResStatus struct { - Code int - Msg string -} type JSONController struct { Controller @@ -475,7 +471,7 @@ func TestParamResetFilter(t *testing.T) { // a response header of `Splat`. The expectation here is that that Header // value should match what the _request's_ router set, not the filter's. - headers := rw.HeaderMap + headers := rw.Result().Header if len(headers["Splat"]) != 1 { t.Errorf( "%s: There was an error in the test. Splat param not set in Header", @@ -660,25 +656,16 @@ func beegoBeforeRouter1(ctx *context.Context) { ctx.WriteString("|BeforeRouter1") } -func beegoBeforeRouter2(ctx *context.Context) { - ctx.WriteString("|BeforeRouter2") -} func beegoBeforeExec1(ctx *context.Context) { ctx.WriteString("|BeforeExec1") } -func beegoBeforeExec2(ctx *context.Context) { - ctx.WriteString("|BeforeExec2") -} func beegoAfterExec1(ctx *context.Context) { ctx.WriteString("|AfterExec1") } -func beegoAfterExec2(ctx *context.Context) { - ctx.WriteString("|AfterExec2") -} func beegoFinishRouter1(ctx *context.Context) { ctx.WriteString("|FinishRouter1") diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index 77685d1e..c0d4bf82 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -133,7 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist check ledis session exist by sid func (lp *Provider) SessionExist(sid string) bool { count, _ := c.Exists([]byte(sid)) - return !(count == 0) + return count != 0 } // SessionRegenerate generate new sid for ledis session diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index 755979c4..85a2d815 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -128,9 +128,12 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } } item, err := client.Get(sid) - if err != nil && err == memcache.ErrCacheMiss { - rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} - return rs, nil + if err != nil { + if err == memcache.ErrCacheMiss { + rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} + return rs, nil + } + return nil, err } var kv map[interface{}]interface{} if len(item.Value) == 0 { diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 4c9251e7..301353ab 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -170,7 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return !(err == sql.ErrNoRows) + return err != sql.ErrNoRows } // SessionRegenerate generate new sid for mysql session diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index ffc27def..0b8b9645 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -184,7 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return !(err == sql.ErrNoRows) + return err != sql.ErrNoRows } // SessionRegenerate generate new sid for postgresql session diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..6ecb2977 --- /dev/null +++ b/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "github.com/astaxie/beego/session" + "github.com/go-redis/redis" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +var redispder = &Provider{} + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = 100 + +// SessionStore redis_sentinel session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) +} + +// Provider redis_sentinel session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Client + masterName string +} + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = DefaultPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = DefaultPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + if len(configs) > 4 { + if configs[4] != "" { + rp.masterName = configs[4] + } else { + rp.masterName = "mymaster" + } + } else { + rp.masterName = "mymaster" + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + MasterName: rp.masterName, + }) + + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_sentinel", redispder) +} diff --git a/session/redis_sentinel/sess_redis_sentinel_test.go b/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..fd4155c6 --- /dev/null +++ b/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + //todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/session/sess_file.go b/session/sess_file.go index c089dade..db143522 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -19,6 +19,7 @@ import ( "io/ioutil" "net/http" "os" + "errors" "path" "path/filepath" "strings" @@ -131,6 +132,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { if strings.ContainsAny(sid, "./") { return nil, nil } + if len(sid) < 2 { + return nil, errors.New("length of the sid is less than 2") + } filepder.lock.Lock() defer filepder.lock.Unlock() diff --git a/session/session.go b/session/session.go index c7e7dc69..46a9f1f0 100644 --- a/session/session.go +++ b/session/session.go @@ -81,6 +81,15 @@ func Register(name string, provide Provider) { provides[name] = provide } +//GetProvider +func GetProvider(name string) (Provider, error) { + provider, ok := provides[name] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) + } + return provider, nil +} + // ManagerConfig define the session config type ManagerConfig struct { CookieName string `json:"cookieName"` diff --git a/template.go b/template.go index cf41cb9b..59875be7 100644 --- a/template.go +++ b/template.go @@ -38,7 +38,7 @@ var ( beeViewPathTemplates = make(map[string]map[string]*template.Template) templatesLock sync.RWMutex // beeTemplateExt stores the template extension which will build - beeTemplateExt = []string{"tpl", "html"} + beeTemplateExt = []string{"tpl", "html", "gohtml"} // beeTemplatePreprocessors stores associations of extension -> preprocessor handler beeTemplateEngines = map[string]templatePreProcessor{} beeTemplateFS = defaultFSFunc @@ -240,7 +240,7 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t * var fileAbsPath string var rParent string var err error - if filepath.HasPrefix(file, "../") { + if strings.HasPrefix(file, "../") { rParent = filepath.Join(filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) } else { @@ -248,10 +248,10 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t * fileAbsPath = filepath.Join(root, file) } f, err := fs.Open(fileAbsPath) - defer f.Close() if err != nil { panic("can't find template file:" + file) } + defer f.Close() data, err := ioutil.ReadAll(f) if err != nil { return nil, [][]string{}, err diff --git a/templatefunc.go b/templatefunc.go index 8c1504aa..ba1ec5eb 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -55,21 +55,21 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - re, _ := regexp.Compile(`\<[\S\s]+?\>`) + re := regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) //remove STYLE - re, _ = regexp.Compile(`\`) + re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") //remove SCRIPT - re, _ = regexp.Compile(`\`) + re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") - re, _ = regexp.Compile(`\<[\S\s]+?\>`) + re = regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllString(html, "\n") - re, _ = regexp.Compile(`\s{2,}`) + re = regexp.MustCompile(`\s{2,}`) html = re.ReplaceAllString(html, "\n") return strings.TrimSpace(html) @@ -85,24 +85,24 @@ func DateFormat(t time.Time, layout string) (datestring string) { var datePatterns = []string{ // year "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 + "y", "06", //A two digit representation of a year Examples: 99 or 03 // month - "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 - "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 - "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec "F", "January", // A full textual representation of a month, such as January or March January through December // day "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 - "j", "2", // Day of the month without leading zeros 1 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 // week - "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "D", "Mon", // A textual representation of a day, three letters Mon through Sun "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday // time - "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 @@ -172,7 +172,7 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{ case "DIY": value, err = AppConfig.DIY(key) default: - err = errors.New("Config keys must be of type String, Bool, Int, Int64, Float, or DIY") + err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") } if err != nil { @@ -297,9 +297,21 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e tag = tags[0] } - value := form.Get(tag) - if len(value) == 0 { - continue + formValues := form[tag] + var value string + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + if defaultValue != "" { + value = defaultValue + } else { + continue + } + } + if len(formValues) == 1 { + value = formValues[0] + if value == "" { + continue + } } switch fieldT.Type.Kind() { @@ -349,6 +361,8 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e if len(value) >= 25 { value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { if strings.Contains(value, "T") { value = value[:19] diff --git a/templatefunc_test.go b/templatefunc_test.go index c7b8fbd3..b4c19c2e 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -111,7 +111,7 @@ func TestHtmlunquote(t *testing.T) { func TestParseForm(t *testing.T) { type ExtendInfo struct { - Hobby string `form:"hobby"` + Hobby []string `form:"hobby"` Memo string } @@ -146,7 +146,7 @@ func TestParseForm(t *testing.T) { "date": []string{"2014-11-12"}, "organization": []string{"beego"}, "title": []string{"CXO"}, - "hobby": []string{"Basketball"}, + "hobby": []string{"", "Basketball", "Football"}, "memo": []string{"nothing"}, } if err := ParseForm(form, u); err == nil { @@ -186,8 +186,14 @@ func TestParseForm(t *testing.T) { if u.Title != "CXO" { t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) } - if u.Hobby != "Basketball" { - t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby) + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) } if len(u.Memo) != 0 { t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) @@ -197,7 +203,6 @@ func TestParseForm(t *testing.T) { func TestRenderForm(t *testing.T) { type user struct { ID int `form:"-"` - tag string `form:"tag"` Name interface{} `form:"username"` Age int `form:"age,text,年龄:"` Sex string diff --git a/toolbox/task.go b/toolbox/task.go index 7e841e89..d1343023 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -20,6 +20,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" ) @@ -32,6 +33,7 @@ type bounds struct { // The bounds for each field. var ( AdminTaskList map[string]Tasker + taskLock sync.Mutex stop chan bool changed chan bool isstart bool @@ -389,6 +391,8 @@ func dayMatches(s *Schedule, t time.Time) bool { // StartTask start all tasks func StartTask() { + taskLock.Lock() + defer taskLock.Unlock() if isstart { //If already started, no need to start another goroutine. return @@ -440,6 +444,8 @@ func run() { // StopTask stop all tasks func StopTask() { + taskLock.Lock() + defer taskLock.Unlock() if isstart { isstart = false stop <- true @@ -449,6 +455,8 @@ func StopTask() { // AddTask add task with name func AddTask(taskname string, t Tasker) { + taskLock.Lock() + defer taskLock.Unlock() t.SetNext(time.Now().Local()) AdminTaskList[taskname] = t if isstart { @@ -458,6 +466,8 @@ func AddTask(taskname string, t Tasker) { // DeleteTask delete task with name func DeleteTask(taskname string) { + taskLock.Lock() + defer taskLock.Unlock() delete(AdminTaskList, taskname) if isstart { changed <- true diff --git a/utils/mail.go b/utils/mail.go index e3fa1c90..42b1e4d4 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -162,7 +162,7 @@ func (e *Email) Bytes() ([]byte, error) { // AttachFile Add attach file to the send mail func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { - if len(args) < 1 && len(args) > 2 { + if len(args) < 1 || len(args) > 2 { // change && to || err = errors.New("Must specify a file name and number of parameters can not exceed at least two") return } @@ -183,7 +183,7 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { // Attach is used to attach content from an io.Reader to the email. // Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { - if len(args) < 1 && len(args) > 2 { + if len(args) < 1 || len(args) > 2 { // change && to || err = errors.New("Must specify the file type and number of parameters can not exceed at least two") return } diff --git a/utils/utils.go b/utils/utils.go index ed885787..3874b803 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,19 +3,78 @@ package utils import ( "os" "path/filepath" + "regexp" "runtime" + "strconv" "strings" ) // GetGOPATHs returns all paths in GOPATH variable. func GetGOPATHs() []string { gopath := os.Getenv("GOPATH") - if gopath == "" && strings.Compare(runtime.Version(), "go1.8") >= 0 { + if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { gopath = defaultGOPATH() } return filepath.SplitList(gopath) } +func compareGoVersion(a, b string) int { + reg := regexp.MustCompile("^\\d*") + + a = strings.TrimPrefix(a, "go") + b = strings.TrimPrefix(b, "go") + + versionsA := strings.Split(a, ".") + versionsB := strings.Split(b, ".") + + for i := 0; i < len(versionsA) && i < len(versionsB); i++ { + versionA := versionsA[i] + versionB := versionsB[i] + + vA, err := strconv.Atoi(versionA) + if err != nil { + str := reg.FindString(versionA) + if str != "" { + vA, _ = strconv.Atoi(str) + } else { + vA = -1 + } + } + + vB, err := strconv.Atoi(versionB) + if err != nil { + str := reg.FindString(versionB) + if str != "" { + vB, _ = strconv.Atoi(str) + } else { + vB = -1 + } + } + + if vA > vB { + // vA = 12, vB = 8 + return 1 + } else if vA < vB { + // vA = 6, vB = 8 + return -1 + } else if vA == -1 { + // vA = rc1, vB = rc3 + return strings.Compare(versionA, versionB) + } + + // vA = vB = 8 + continue + } + + if len(versionsA) > len(versionsB) { + return 1 + } else if len(versionsA) == len(versionsB) { + return 0 + } + + return -1 +} + func defaultGOPATH() string { env := "HOME" if runtime.GOOS == "windows" { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..ced6f63f --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,36 @@ +package utils + +import ( + "testing" +) + +func TestCompareGoVersion(t *testing.T) { + targetVersion := "go1.8" + if compareGoVersion("go1.12.4", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8.7", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7.6", targetVersion) != -1 { + t.Error("should be -1") + } + + if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8rc1", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7rc1", targetVersion) != -1 { + t.Error("should be -1") + } +} diff --git a/validation/validation_test.go b/validation/validation_test.go index f97105fd..3146766b 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -268,6 +268,18 @@ func TestMobile(t *testing.T) { if !valid.Mobile("+8614700008888", "mobile").Ok { t.Error("\"+8614700008888\" is a valid mobile phone number should be true") } + if !valid.Mobile("17300008888", "mobile").Ok { + t.Error("\"17300008888\" is a valid mobile phone number should be true") + } + if !valid.Mobile("+8617100008888", "mobile").Ok { + t.Error("\"+8617100008888\" is a valid mobile phone number should be true") + } + if !valid.Mobile("8617500008888", "mobile").Ok { + t.Error("\"8617500008888\" is a valid mobile phone number should be true") + } + if valid.Mobile("8617400008888", "mobile").Ok { + t.Error("\"8617400008888\" is a valid mobile phone number should be false") + } } func TestTel(t *testing.T) { @@ -453,7 +465,7 @@ func TestPointer(t *testing.T) { u := User{ ReqEmail: nil, - Email: nil, + Email: nil, } valid := Validation{} @@ -468,7 +480,7 @@ func TestPointer(t *testing.T) { validEmail := "a@a.com" u = User{ ReqEmail: &validEmail, - Email: nil, + Email: nil, } valid = Validation{RequiredFirst: true} @@ -482,7 +494,7 @@ func TestPointer(t *testing.T) { u = User{ ReqEmail: &validEmail, - Email: nil, + Email: nil, } valid = Validation{} @@ -497,7 +509,7 @@ func TestPointer(t *testing.T) { invalidEmail := "a@a" u = User{ ReqEmail: &validEmail, - Email: &invalidEmail, + Email: &invalidEmail, } valid = Validation{RequiredFirst: true} @@ -511,7 +523,7 @@ func TestPointer(t *testing.T) { u = User{ ReqEmail: &validEmail, - Email: &invalidEmail, + Email: &invalidEmail, } valid = Validation{} @@ -524,19 +536,18 @@ func TestPointer(t *testing.T) { } } - func TestCanSkipAlso(t *testing.T) { type User struct { ID int - Email string `valid:"Email"` - ReqEmail string `valid:"Required;Email"` - MatchRange int `valid:"Range(10, 20)"` + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` } u := User{ - ReqEmail: "a@a.com", - Email: "", + ReqEmail: "a@a.com", + Email: "", MatchRange: 0, } @@ -560,4 +571,3 @@ func TestCanSkipAlso(t *testing.T) { } } - diff --git a/validation/validators.go b/validation/validators.go index 4dff9c0b..dc18b11e 100644 --- a/validation/validators.go +++ b/validation/validators.go @@ -632,7 +632,7 @@ func (b Base64) GetLimitValue() interface{} { } // just for chinese mobile phone number -var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`) +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][01356789]|[4][579]))\d{8}$`) // Mobile check struct type Mobile struct {