diff --git a/.gitignore b/.gitignore index 9806457b..e1b65291 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea +.vscode .DS_Store *.swp *.swo diff --git a/.gosimpleignore b/.gosimpleignore new file mode 100644 index 00000000..84df9b95 --- /dev/null +++ b/.gosimpleignore @@ -0,0 +1,4 @@ +github.com/astaxie/beego/*/*:S1012 +github.com/astaxie/beego/*:S1012 +github.com/astaxie/beego/*/*:S1007 +github.com/astaxie/beego/*:S1007 \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 93536488..d4b19015 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,8 +2,8 @@ language: go go: - 1.6 - - 1.5.3 - - 1.4.3 + - 1.7 + - 1.8 services: - redis-server - mysql @@ -31,6 +31,10 @@ install: - go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/ledis - go get github.com/ssdb/gossdb/ssdb + - go get github.com/cloudflare/golz4 + - go get github.com/gogo/protobuf/proto + - go get -u honnef.co/go/tools/cmd/gosimple + - go get -u github.com/mdempsky/unconvert before_script: - psql --version - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" @@ -45,5 +49,7 @@ after_script: - rm -rf ./res/var/* script: - go test -v ./... + - gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/) + - unconvert $(go list ./... | grep -v /vendor/) addons: postgresql: "9.4" diff --git a/README.md b/README.md index d3c92d84..c08927fb 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,17 @@ -## Beego - -[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) -[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) -[![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) +# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) beego is used for rapid development of RESTful APIs, web apps and backend services in Go. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. -More info [beego.me](http://beego.me) +###### More info at [beego.me](http://beego.me). -##Quick Start -######Download and install +## Quick Start + +#### Download and install go get github.com/astaxie/beego -######Create file `hello.go` +#### Create file `hello.go` ```go package main @@ -24,15 +21,16 @@ func main(){ beego.Run() } ``` -######Build and run -```bash +#### Build and run + go build hello.go ./hello -``` -######Congratulations! -You just built your first beego app. -Open your browser and visit `http://localhost:8080`. -Please see [Documentation](http://beego.me/docs) for more. + +#### Go to [http://localhost:8080](http://localhost:8080) + +Congratulations! You've just built your first **beego** app. + +###### Please see [Documentation](http://beego.me/docs) for more. ## Features @@ -56,7 +54,7 @@ Please see [Documentation](http://beego.me/docs) for more. * [http://beego.me/community](http://beego.me/community) * Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232) -## LICENSE +## License beego source code is licensed under the Apache Licence, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0.html). diff --git a/admin.go b/admin.go index a2b2f53a..b4e3068f 100644 --- a/admin.go +++ b/admin.go @@ -157,8 +157,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) { resultList := new([][]string) for _, f := range bf { var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), + f.pattern, + utils.GetFuncName(f.filterFunc), } *resultList = append(*resultList, result) } @@ -213,7 +213,7 @@ func printTree(resultList *[][]string, t *Tree) { var result = []string{ v.pattern, fmt.Sprintf("%s", v.methods), - fmt.Sprintf("%s", v.controllerType), + v.controllerType.String(), } *resultList = append(*resultList, result) } else if v.routerType == routerTypeRESTFul { @@ -287,16 +287,16 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) { for name, h := range toolbox.AdminCheckList { if err := h.Check(); err != nil { result = []string{ - fmt.Sprintf("error"), - fmt.Sprintf("%s", name), - fmt.Sprintf("%s", err.Error()), + "error", + name, + err.Error(), } } else { result = []string{ - fmt.Sprintf("success"), - fmt.Sprintf("%s", name), - fmt.Sprintf("OK"), + "success", + name, + "OK", } } @@ -341,8 +341,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { for tname, tk := range toolbox.AdminTaskList { result = []string{ tname, - fmt.Sprintf("%s", tk.GetSpec()), - fmt.Sprintf("%s", tk.GetStatus()), + tk.GetSpec(), + tk.GetStatus(), tk.GetPrev().String(), } *resultList = append(*resultList, result) diff --git a/admin_test.go b/admin_test.go index 0bf985f2..2348792e 100644 --- a/admin_test.go +++ b/admin_test.go @@ -65,6 +65,7 @@ func oldMap() map[string]interface{} { m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum m["BConfig.Log.Outputs"] = BConfig.Log.Outputs diff --git a/app.go b/app.go index 32776298..25ea2a04 100644 --- a/app.go +++ b/app.go @@ -348,9 +348,9 @@ func Any(rootpath string, f FilterFunc) *App { // Handler used to register a Handler router // usage: -// beego.Handler("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) func Handler(rootpath string, h http.Handler, options ...interface{}) *App { BeeApp.Handlers.Handler(rootpath, h, options...) return BeeApp diff --git a/beego.go b/beego.go index 94ec92b2..7acf4a13 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.7.1" + VERSION = "1.8.1" // DEV is for develop DEV = "dev" diff --git a/cache/conv_test.go b/cache/conv_test.go index cf792fa6..b90e224a 100644 --- a/cache/conv_test.go +++ b/cache/conv_test.go @@ -118,14 +118,14 @@ func TestGetFloat64(t *testing.T) { func TestGetBool(t *testing.T) { var t1 = true - if true != GetBool(t1) { + if !GetBool(t1) { t.Error("get bool from bool error") } var t2 = "true" - if true != GetBool(t2) { + if !GetBool(t2) { t.Error("get bool from string error") } - if false != GetBool(nil) { + if GetBool(nil) { t.Error("get bool from nil error") } } diff --git a/cache/file.go b/cache/file.go index 4b030980..691ce7cd 100644 --- a/cache/file.go +++ b/cache/file.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "os" "path/filepath" "reflect" @@ -222,33 +223,13 @@ func exists(path string) (bool, error) { // FileGetContents Get bytes to file. // if non-exist, create this file. func FileGetContents(filename string) (data []byte, e error) { - f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) - if e != nil { - return - } - defer f.Close() - stat, e := f.Stat() - if e != nil { - return - } - data = make([]byte, stat.Size()) - result, e := f.Read(data) - if e != nil || int64(result) != stat.Size() { - return nil, e - } - return + return ioutil.ReadFile(filename) } // FilePutContents Put bytes to file. // if non-exist, create this file. func FilePutContents(filename string, content []byte) error { - fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) - if err != nil { - return err - } - defer fp.Close() - _, err = fp.Write(content) - return err + return ioutil.WriteFile(filename, content, os.ModePerm) } // GobEncode Gob encodes file cache item. diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go index 972361f7..0624f5fa 100644 --- a/cache/memcache/memcache.go +++ b/cache/memcache/memcache.go @@ -146,10 +146,7 @@ func (rc *Cache) IsExist(key string) bool { } } _, err := rc.conn.Get(key) - if err != nil { - return false - } - return true + return !(err != nil) } // ClearAll clear all cached in memcache. diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 781e3836..3e71fb53 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -137,7 +137,7 @@ func (rc *Cache) IsExist(key string) bool { if err != nil { return false } - if v == false { + if !v { if _, err = rc.do("HDEL", rc.key, key); err != nil { return false } diff --git a/cache/ssdb/ssdb.go b/cache/ssdb/ssdb.go index bfee69ce..fa2ce04b 100644 --- a/cache/ssdb/ssdb.go +++ b/cache/ssdb/ssdb.go @@ -53,7 +53,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} { resSize := len(res) if err == nil { for i := 1; i < resSize; i += 2 { - values = append(values, string(res[i+1])) + values = append(values, res[i+1]) } return values } @@ -71,10 +71,7 @@ func (rc *Cache) DelMulti(keys []string) error { } } _, err := rc.conn.Do("multi_del", keys) - if err != nil { - return err - } - return nil + return err } // Put put value to memcache. only support string. @@ -113,10 +110,7 @@ func (rc *Cache) Delete(key string) error { } } _, err := rc.conn.Del(key) - if err != nil { - return err - } - return nil + return err } // Incr increase counter. @@ -152,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool { if err != nil { return false } - if resp[1] == "1" { + if len(resp) == 2 && resp[1] == "1" { return true } return false @@ -175,7 +169,7 @@ func (rc *Cache) ClearAll() error { } keys := []string{} for i := 1; i < size; i += 2 { - keys = append(keys, string(resp[i])) + keys = append(keys, resp[i]) } _, e := rc.conn.Do("multi_del", keys) if e != nil { @@ -229,10 +223,7 @@ func (rc *Cache) connectInit() error { } var err error rc.conn, err = ssdb.Connect(host, port) - if err != nil { - return err - } - return nil + return err } func init() { diff --git a/config.go b/config.go index a4f40611..e6e99570 100644 --- a/config.go +++ b/config.go @@ -41,6 +41,7 @@ type Config struct { EnableGzip bool MaxMemory int64 EnableErrorsShow bool + EnableErrorsRender bool Listen Listen WebConfig WebConfig Log LogConfig @@ -86,17 +87,18 @@ type WebConfig struct { // SessionConfig holds session related config type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHttpHeader string - EnableSidInUrlQuery bool // enable get the sessionId from Url Query params + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params } // LogConfig holds Log related config @@ -170,7 +172,7 @@ func recoverPanic(ctx *context.Context) { logs.Critical(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) } - if BConfig.RunMode == DEV { + if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { showErr(err, ctx, stack) } } @@ -188,6 +190,7 @@ func newBConfig() *Config { EnableGzip: false, MaxMemory: 1 << 26, //64MB EnableErrorsShow: true, + EnableErrorsRender: true, Listen: Listen{ Graceful: false, ServerTimeOut: 0, @@ -221,17 +224,18 @@ func newBConfig() *Config { XSRFKey: "beegoxsrf", XSRFExpire: 0, Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", - EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers - SessionNameInHttpHeader: "Beegosessionid", - EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params }, }, Log: LogConfig{ @@ -252,6 +256,9 @@ func parseConfig(appConfigPath string) (err error) { } func assignConfig(ac config.Configer) error { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode @@ -259,10 +266,6 @@ func assignConfig(ac config.Configer) error { BConfig.RunMode = runMode } - for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - if sd := ac.String("StaticDir"); sd != "" { BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) @@ -294,6 +297,10 @@ func assignConfig(ac config.Configer) error { } if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) los := strings.Split(lo, ";") for _, v := range los { if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { @@ -338,7 +345,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) { case reflect.String: pf.SetString(ac.DefaultString(name, pf.String())) case reflect.Int, reflect.Int64: - pf.SetInt(int64(ac.DefaultInt64(name, pf.Int()))) + pf.SetInt(ac.DefaultInt64(name, pf.Int())) case reflect.Bool: pf.SetBool(ac.DefaultBool(name, pf.Bool())) case reflect.Struct: diff --git a/config/config.go b/config/config.go index 9f41fb79..e8201a24 100644 --- a/config/config.go +++ b/config/config.go @@ -43,6 +43,8 @@ package config import ( "fmt" "os" + "reflect" + "time" ) // Configer defines how to get and set value from configuration raw data. @@ -204,3 +206,37 @@ func ParseBool(val interface{}) (value bool, err error) { } return false, fmt.Errorf("parsing : invalid syntax") } + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/config/env/env.go b/config/env/env.go new file mode 100644 index 00000000..a819e51a --- /dev/null +++ b/config/env/env.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/utils" +) + +var env *utils.BeeMap + +func init() { + env = utils.NewBeeMap() + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Set(splits[0], os.Getenv(splits[0])) + } +} + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + if val := env.Get(key); val != nil { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val := env.Get(key); val != nil { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Set(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + items := env.Items() + envs := make(map[string]string, env.Count()) + + for key, val := range items { + switch key := key.(type) { + case string: + switch val := val.(type) { + case string: + envs[key] = val + } + } + } + return envs +} diff --git a/config/env/env_test.go b/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/config/ini.go b/config/ini.go index 9371fe61..7e776a9a 100644 --- a/config/ini.go +++ b/config/ini.go @@ -18,15 +18,14 @@ import ( "bufio" "bytes" "errors" - "fmt" "io" "io/ioutil" "os" - "path" + "os/user" + "path/filepath" "strconv" "strings" "sync" - "time" ) var ( @@ -51,24 +50,26 @@ func (ini *IniConfig) Parse(name string) (Configer, error) { } func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { - file, err := os.Open(name) + data, err := ioutil.ReadFile(name) if err != nil { return nil, err } + return ini.parseData(filepath.Dir(name), data) +} + +func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { cfg := &IniConfigContainer{ - file.Name(), - make(map[string]map[string]string), - make(map[string]string), - make(map[string]string), - sync.RWMutex{}, + data: make(map[string]map[string]string), + sectionComment: make(map[string]string), + keyComment: make(map[string]string), + RWMutex: sync.RWMutex{}, } cfg.Lock() defer cfg.Unlock() - defer file.Close() var comment bytes.Buffer - buf := bufio.NewReader(file) + buf := bufio.NewReader(bytes.NewBuffer(data)) // check the BOM head, err := buf.Peek(3) if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { @@ -129,16 +130,20 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { // handle include "other.conf" if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + includefiles := strings.Fields(key) if includefiles[0] == "include" && len(includefiles) == 2 { + otherfile := strings.Trim(includefiles[1], "\"") - if !path.IsAbs(otherfile) { - otherfile = path.Join(path.Dir(name), otherfile) + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(dir, otherfile) } + i, err := ini.parseFile(otherfile) if err != nil { return nil, err } + for sec, dt := range i.data { if _, ok := cfg.data[sec]; !ok { cfg.data[sec] = make(map[string]string) @@ -147,12 +152,15 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { cfg.data[sec][k] = v } } + for sec, comm := range i.sectionComment { cfg.sectionComment[sec] = comm } + for k, comm := range i.keyComment { cfg.keyComment[k] = comm } + continue } } @@ -176,20 +184,25 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { } // ParseData parse ini the data +// When include other.conf,other.conf is either absolute directory +// or under beego in default temporary directory(/tmp/beego[-username]). func (ini *IniConfig) ParseData(data []byte) (Configer, error) { - // Save memory data to temporary file - tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) - os.MkdirAll(path.Dir(tmpName), os.ModePerm) - if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { return nil, err } - return ini.Parse(tmpName) + + return ini.parseData(dir, data) } // IniConfigContainer A Config represents the ini configuration. // When set and get value, support key as section:name type. type IniConfigContainer struct { - filename string data map[string]map[string]string // section=> key:val sectionComment map[string]string // section : comment keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. @@ -296,7 +309,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro if v, ok := c.data[section]; ok { return v, nil } - return nil, errors.New("not exist setction") + return nil, errors.New("not exist section") } // SaveConfigFile save the config into file. @@ -392,11 +405,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { } } } - - if _, err = buf.WriteTo(f); err != nil { - return err - } - return nil + _, err = buf.WriteTo(f) + return err } // Set writes a new value for key. diff --git a/config/ini_test.go b/config/ini_test.go index 83ff3668..ffcdb294 100644 --- a/config/ini_test.go +++ b/config/ini_test.go @@ -181,7 +181,7 @@ name=mysql cfgData := string(data) datas := strings.Split(saveResult, "\n") for _, line := range datas { - if strings.Contains(cfgData, line+"\n") == false { + if !strings.Contains(cfgData, line+"\n") { t.Fatalf("different after save ini config file. need contains %q", line) } } diff --git a/config/xml/xml.go b/config/xml/xml.go index 0c4e4d27..b82bf403 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -35,11 +35,9 @@ import ( "fmt" "io/ioutil" "os" - "path" "strconv" "strings" "sync" - "time" "github.com/astaxie/beego/config" "github.com/beego/x2j" @@ -52,36 +50,26 @@ type Config struct{} // Parse returns a ConfigContainer with parsed xml config map. func (xc *Config) Parse(filename string) (config.Configer, error) { - file, err := os.Open(filename) + context, err := ioutil.ReadFile(filename) if err != nil { return nil, err } - defer file.Close() + return xc.ParseData(context) +} + +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { x := &ConfigContainer{data: make(map[string]interface{})} - content, err := ioutil.ReadAll(file) - if err != nil { - return nil, err - } - d, err := x2j.DocToMap(string(content)) + d, err := x2j.DocToMap(string(data)) if err != nil { return nil, err } x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) - return x, nil -} -// ParseData xml data -func (xc *Config) ParseData(data []byte) (config.Configer, error) { - // Save memory data to temporary file - tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) - os.MkdirAll(path.Dir(tmpName), os.ModePerm) - if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { - return nil, err - } - return xc.Parse(tmpName) + return x, nil } // ConfigContainer A Config represents the xml configuration. @@ -193,10 +181,14 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil } - return nil, errors.New("not exist setction") + return nil, fmt.Errorf("section '%s' not found", section) } // SaveConfigFile save the config into file diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go index d8a09a59..346c866e 100644 --- a/config/xml/xml_test.go +++ b/config/xml/xml_test.go @@ -37,6 +37,10 @@ func TestXML(t *testing.T) { true ${GOPATH} ${GOPATH||/home/go} + +1 +MySection + ` keyValue = map[string]interface{}{ @@ -65,11 +69,22 @@ func TestXML(t *testing.T) { } f.Close() defer os.Remove("testxml.conf") + xmlconf, err := config.NewConfig("xml", "testxml.conf") if err != nil { t.Fatal(err) } + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + for k, v := range keyValue { var ( diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index e3260215..51fe44d3 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -37,10 +37,8 @@ import ( "io/ioutil" "log" "os" - "path" "strings" "sync" - "time" "github.com/astaxie/beego/config" "github.com/beego/goyaml2" @@ -63,26 +61,30 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) { // ParseData parse yaml data func (yaml *Config) ParseData(data []byte) (config.Configer, error) { - // Save memory data to temporary file - tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) - os.MkdirAll(path.Dir(tmpName), os.ModePerm) - if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { + cnf, err := parseYML(data) + if err != nil { return nil, err } - return yaml.Parse(tmpName) + + return &ConfigContainer{ + data: cnf, + }, nil } // ReadYmlReader Read yaml file to map. // if json like, use json package, unless goyaml2 package. func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { - f, err := os.Open(path) + buf, err := ioutil.ReadFile(path) if err != nil { return } - defer f.Close() - buf, err := ioutil.ReadAll(f) - if err != nil || len(buf) < 3 { + return parseYML(buf) +} + +// parseYML parse yaml formatted []byte to map. +func parseYML(buf []byte) (cnf map[string]interface{}, err error) { + if len(buf) < 3 { return } @@ -250,7 +252,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) if v, ok := c.data[section]; ok { return v.(map[string]string), nil } - return nil, errors.New("not exist setction") + return nil, errors.New("not exist section") } // SaveConfigFile save the config into file diff --git a/context/input.go b/context/input.go index 1e6eaf71..d9015ce3 100644 --- a/context/input.go +++ b/context/input.go @@ -413,7 +413,13 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error { if !value.CanSet() { return errors.New("beego: non-settable variable passed to Bind: " + key) } - rv := input.bind(key, value.Type()) + typ := value.Type() + // Get real type if dest define with interface{}. + // e.g var dest interface{} dest=1.0 + if value.Kind() == reflect.Interface { + typ = value.Elem().Type() + } + rv := input.bind(key, typ) if !rv.IsValid() { return errors.New("beego: reflect value is empty") } @@ -422,6 +428,9 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error { } func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } rv := reflect.Zero(typ) switch typ.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: diff --git a/context/input_test.go b/context/input_test.go index e64addba..db812a0f 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -15,81 +15,97 @@ package context import ( - "fmt" "net/http" "net/http/httptest" "reflect" "testing" ) -func TestParse(t *testing.T) { - r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - beegoInput.ParseFormOrMulitForm(1 << 20) +func TestBind(t *testing.T) { + type testItem struct { + field string + empty interface{} + want interface{} + } + type Human struct { + ID int + Nick string + Pwd string + Ms bool + } - var id int - err := beegoInput.Bind(&id, "id") - if id != 123 || err != nil { - t.Fatal("id should has int value") - } - fmt.Println(id) + cases := []struct { + request string + valueGp []testItem + }{ + {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, - var isok bool - err = beegoInput.Bind(&isok, "isok") - if !isok || err != nil { - t.Fatal("isok should be true") - } - fmt.Println(isok) + {"/?p=", []testItem{{"p", "", ""}}}, + {"/?p=str", []testItem{{"p", "", "str"}}}, - var float float64 - err = beegoInput.Bind(&float, "ft") - if float != 1.2 || err != nil { - t.Fatal("float should be equal to 1.2") - } - fmt.Println(float) + {"/?p=123", []testItem{{"p", 0, 123}}}, + {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, - ol := make([]int, 0, 2) - err = beegoInput.Bind(&ol, "ol") - if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 { - t.Fatal("ol should has two elements") - } - fmt.Println(ol) + {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, + {"/?p=1", []testItem{{"p", false, true}}}, - ul := make([]string, 0, 2) - err = beegoInput.Bind(&ul, "ul") - if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" { - t.Fatal("ul should has two elements") - } - fmt.Println(ul) + {"/?p=true", []testItem{{"p", false, true}}}, + {"/?p=ON", []testItem{{"p", false, true}}}, + {"/?p=on", []testItem{{"p", false, true}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + {"/?p=2", []testItem{{"p", false, false}}}, + {"/?p=false", []testItem{{"p", false, false}}}, - type User struct { - Name string - } - user := User{} - err = beegoInput.Bind(&user, "user") - if err != nil || user.Name != "astaxie" { - t.Fatal("user should has name") - } - fmt.Println(user) -} + {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, + {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, -func TestParse2(t *testing.T) { - r, _ := http.NewRequest("GET", "/?user[0][Username]=Raph&user[1].Username=Leo&user[0].Password=123456&user[1][Password]=654321", nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - beegoInput.ParseFormOrMulitForm(1 << 20) - type User struct { - Username string - Password string + {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, + {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, + + {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, + + {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, + + {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, + {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, + {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + []testItem{{"human", []Human{}, []Human{ + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + }}}}, + + { + "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", + []testItem{ + {"id", 0, 123}, + {"isok", false, true}, + {"ft", 0.0, 1.2}, + {"ol", []int{}, []int{1, 2}}, + {"ul", []string{}, []string{"str", "array"}}, + {"human", Human{}, Human{Nick: "astaxie"}}, + }, + }, } - var users []User - err := beegoInput.Bind(&users, "user") - fmt.Println(users) - if err != nil || users[0].Username != "Raph" || users[0].Password != "123456" || users[1].Username != "Leo" || users[1].Password != "654321" { - t.Fatal("users info wrong") + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + for _, item := range c.valueGp { + got := item.empty + err := beegoInput.Bind(&got, item.field) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, item.want) { + t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) + } + } + } } diff --git a/context/output.go b/context/output.go index c09b9d19..c15f9fe7 100644 --- a/context/output.go +++ b/context/output.go @@ -67,6 +67,7 @@ func (output *BeegoOutput) Body(content []byte) error { } if b, n, _ := WriteBody(encoding, buf, content); b { output.Header("Content-Encoding", n) + output.Header("Content-Length", strconv.Itoa(buf.Len())) } else { output.Header("Content-Length", strconv.Itoa(len(content))) } @@ -104,7 +105,7 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface switch { case maxAge > 0: fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) - case maxAge < 0: + case maxAge <= 0: fmt.Fprintf(&b, "; Max-Age=0") } } @@ -330,16 +331,17 @@ func (output *BeegoOutput) IsServerError() bool { func stringsToJSON(str string) string { rs := []rune(str) - jsons := "" + var jsons bytes.Buffer for _, r := range rs { rint := int(r) if rint < 128 { - jsons += string(r) + jsons.WriteRune(r) } else { - jsons += "\\u" + strconv.FormatInt(int64(rint), 16) // json + jsons.WriteString("\\u") + jsons.WriteString(strconv.FormatInt(int64(rint), 16)) } } - return jsons + return jsons.String() } // Session sets session item value with given key. diff --git a/controller.go b/controller.go index c7eb118d..c2a327b3 100644 --- a/controller.go +++ b/controller.go @@ -69,6 +69,7 @@ type Controller struct { // template data TplName string + ViewPath string Layout string LayoutSections map[string]string // the key is the section name and the value is the template name TplPrefix string @@ -185,7 +186,11 @@ func (c *Controller) Render() error { if err != nil { return err } - c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + + if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + } + return c.Ctx.Output.Body(rb) } @@ -209,7 +214,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { continue } buf.Reset() - err = ExecuteTemplate(&buf, sectionTpl, c.Data) + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) if err != nil { return nil, err } @@ -218,7 +223,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { } buf.Reset() - ExecuteTemplate(&buf, c.Layout, c.Data) + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) } return buf.Bytes(), err } @@ -244,9 +249,16 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) { } } } - BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) + BuildTemplate(c.viewPath(), buildFiles...) } - return buf, ExecuteTemplate(&buf, c.TplName, c.Data) + return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) +} + +func (c *Controller) viewPath() string { + if c.ViewPath == "" { + return BConfig.WebConfig.ViewsPath + } + return c.ViewPath } // Redirect sends the redirection response to url with status code. @@ -302,7 +314,7 @@ func (c *Controller) ServeJSON(encoding ...bool) { if BConfig.RunMode == PROD { hasIndent = false } - if len(encoding) > 0 && encoding[0] == true { + if len(encoding) > 0 && encoding[0] { hasEncoding = true } c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) @@ -399,6 +411,16 @@ func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { return int8(i64), err } +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + // GetInt16 returns input as an int16 or the default value while it's present and input is blank func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { strv := c.Ctx.Input.Query(key) @@ -409,6 +431,16 @@ func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { return int16(i64), err } +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + // GetInt32 returns input as an int32 or the default value while it's present and input is blank func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { strv := c.Ctx.Input.Query(key) @@ -419,6 +451,16 @@ func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { return int32(i64), err } +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + // GetInt64 returns input value as int64 or the default value while it's present and input is blank. func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { strv := c.Ctx.Input.Query(key) @@ -428,6 +470,15 @@ func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { return strconv.ParseInt(strv, 10, 64) } +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + // GetBool returns input value as bool or the default value while it's present and input is blank. func (c *Controller) GetBool(key string, def ...bool) (bool, error) { strv := c.Ctx.Input.Query(key) @@ -453,7 +504,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, } // GetFiles return multi-upload files -// files, err:=c.Getfiles("myfiles") +// files, err:=c.GetFiles("myfiles") // if err != nil { // http.Error(w, err.Error(), http.StatusNoContent) // return diff --git a/controller_test.go b/controller_test.go index 51d3a5b7..1e53416d 100644 --- a/controller_test.go +++ b/controller_test.go @@ -15,9 +15,13 @@ package beego import ( + "math" + "strconv" "testing" "github.com/astaxie/beego/context" + "os" + "path/filepath" ) func TestGetInt(t *testing.T) { @@ -75,3 +79,103 @@ func TestGetInt64(t *testing.T) { t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) } } + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + dir1 := "_beeTmp" + dir2 := "_beeTmp2" + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} diff --git a/flash_test.go b/flash_test.go index 640d54de..d5e9608d 100644 --- a/flash_test.go +++ b/flash_test.go @@ -48,7 +48,7 @@ func TestFlashHeader(t *testing.T) { // match for the expected header res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") // validate the assertion - if res != true { + if !res { t.Errorf("TestFlashHeader() unable to validate flash message") } } diff --git a/grace/grace.go b/grace/grace.go index af4e9068..6ebf8455 100644 --- a/grace/grace.go +++ b/grace/grace.go @@ -85,23 +85,31 @@ var ( isChild bool socketOrder string - once sync.Once + + hookableSignals []os.Signal ) -func onceInit() { - regLock = &sync.Mutex{} +func init() { flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + + regLock = &sync.Mutex{} runningServers = make(map[string]*Server) runningServersOrder = []string{} socketPtrOffsetMap = make(map[string]uint) + + hookableSignals = []os.Signal{ + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + } } // NewServer returns a new graceServer. func NewServer(addr string, handler http.Handler) (srv *Server) { - once.Do(onceInit) regLock.Lock() defer regLock.Unlock() + if !flag.Parsed() { flag.Parse() } diff --git a/grace/listener.go b/grace/listener.go index 5439d0b2..823d3cce 100644 --- a/grace/listener.go +++ b/grace/listener.go @@ -21,7 +21,7 @@ func newGraceListener(l net.Listener, srv *Server) (el *graceListener) { server: srv, } go func() { - _ = <-el.stop + <-el.stop el.stopped = true el.stop <- el.Listener.Close() }() diff --git a/grace/server.go b/grace/server.go index 101bda56..cc985552 100644 --- a/grace/server.go +++ b/grace/server.go @@ -162,9 +162,7 @@ func (srv *Server) handleSignals() { signal.Notify( srv.sigChan, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, + hookableSignals..., ) pid := syscall.Getpid() @@ -290,3 +288,19 @@ func (srv *Server) fork() (err error) { return } + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { + if ppFlag != PreSignal && ppFlag != PostSignal { + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal.") + return + } + for _, s := range hookableSignals { + if s == sig { + srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) + return + } + } + err = fmt.Errorf("Signal '%v' is not supported.", sig) + return +} diff --git a/hooks.go b/hooks.go index 0c7d05fe..091ecbc7 100644 --- a/hooks.go +++ b/hooks.go @@ -53,10 +53,11 @@ func registerSession() error { conf.Secure = BConfig.Listen.EnableHTTPS conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly conf.Domain = BConfig.WebConfig.Session.SessionDomain - conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.EnableSidInHttpHeader - conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHttpHeader - conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.EnableSidInUrlQuery + conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { return err @@ -71,7 +72,8 @@ func registerSession() error { } func registerTemplate() error { - if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { + defer lockViewPaths() + if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { if BConfig.RunMode == DEV { logs.Warn(err) } diff --git a/httplib/httplib.go b/httplib/httplib.go index 7e6f2700..32edc5c4 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -136,9 +136,11 @@ type BeegoHTTPSettings struct { TLSClientConfig *tls.Config Proxy func(*http.Request) (*url.URL, error) Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error EnableCookie bool Gzip bool DumpBody bool + Retries int // if set to -1 means will retry forever } // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. @@ -188,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { return b } +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.setting.Retries = times + return b +} + // DumpBody setting whether need to Dump the Body. func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump @@ -265,6 +276,15 @@ func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) return b } +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { @@ -315,7 +335,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) func (b *BeegoHTTPRequest) buildURL(paramBody string) { // build GET url with query string if b.req.Method == "GET" && len(paramBody) > 0 { - if strings.Index(b.url, "?") != -1 { + if strings.Contains(b.url, "?") { b.url += "&" + paramBody } else { b.url = b.url + "?" + paramBody @@ -380,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { } // DoRequest will do the client.Do -func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { var paramBody string if len(b.params) > 0 { var buf bytes.Buffer @@ -446,6 +466,10 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { b.req.Header.Set("User-Agent", b.setting.UserAgent) } + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + if b.setting.ShowDebug { dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) if err != nil { @@ -453,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { } b.dump = dump } - return client.Do(b.req) + // retries default value is 0, it will run once. + // retries equal to -1, it will run forever until success + // retries is setted, it will retries fixed times. + for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { + resp, err = client.Do(b.req) + if err == nil { + break + } + } + return resp, err } // String returns the body string in response. diff --git a/logs/alils/alils.go b/logs/alils/alils.go new file mode 100644 index 00000000..30a09243 --- /dev/null +++ b/logs/alils/alils.go @@ -0,0 +1,192 @@ +package alils + +import ( + "encoding/json" + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" + "strings" + "sync" + "time" +) + +const ( + CacheSize int = 64 + Delimiter string = "##" +) + +type AliLSConfig struct { + Project string `json:"project"` + Endpoint string `json:"endpoint"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` + LogStore string `json:"log_store"` + Topics []string `json:"topics"` + Source string `json:"source"` + Level int `json:"level"` + FlushWhen int `json:"flush_when"` +} + +// aliLSWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type aliLSWriter struct { + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + AliLSConfig +} + +// 创建提供Logger接口的日志服务 +func NewAliLS() logs.Logger { + alils := new(aliLSWriter) + alils.Level = logs.LevelTrace + return alils +} + +// 读取配置 +// 初始化必要的数据结构 +func (c *aliLSWriter) Init(jsonConfig string) (err error) { + + json.Unmarshal([]byte(jsonConfig), c) + + if c.FlushWhen > CacheSize { + c.FlushWhen = CacheSize + } + + // 初始化Project + prj := &LogProject{ + Name: c.Project, + Endpoint: c.Endpoint, + AccessKeyId: c.KeyID, + AccessKeySecret: c.KeySecret, + } + + // 获取logstore + c.store, err = prj.GetLogStore(c.LogStore) + if err != nil { + return err + } + + // 创建默认Log Group + c.group = append(c.group, &LogGroup{ + Topic: proto.String(""), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + }) + + // 创建其它Log Group + c.groupMap = make(map[string]*LogGroup) + for _, topic := range c.Topics { + + lg := &LogGroup{ + Topic: proto.String(topic), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + } + + c.group = append(c.group, lg) + c.groupMap[topic] = lg + } + + if len(c.group) == 1 { + c.withMap = false + } else { + c.withMap = true + } + + c.lock = &sync.Mutex{} + + return nil +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { + + if level > c.Level { + return nil + } + + var topic string + var content string + var lg *LogGroup + if c.withMap { + + // 解析出Topic,并匹配LogGroup + strs := strings.SplitN(msg, Delimiter, 2) + if len(strs) == 2 { + pos := strings.LastIndex(strs[0], " ") + topic = strs[0][pos+1 : len(strs[0])] + content = strs[0][0:pos] + strs[1] + lg = c.groupMap[topic] + } + + // 默认发到空Topic + if lg == nil { + topic = "" + content = msg + lg = c.group[0] + } + } else { + topic = "" + content = msg + lg = c.group[0] + } + + // 生成日志 + c1 := &Log_Content{ + Key: proto.String("msg"), + Value: proto.String(content), + } + + l := &Log{ + Time: proto.Uint32(uint32(when.Unix())), // 填写日志时间 + Contents: []*Log_Content{ + c1, + }, + } + + c.lock.Lock() + lg.Logs = append(lg.Logs, l) + c.lock.Unlock() + + // 满足条件则Flush + if len(lg.Logs) >= c.FlushWhen { + c.flush(lg) + } + + return nil +} + +// Flush implementing method. empty. +func (c *aliLSWriter) Flush() { + + // flush所有group + for _, lg := range c.group { + c.flush(lg) + } +} + +// Destroy destroy connection writer and close tcp listener. +func (c *aliLSWriter) Destroy() { +} + +func (c *aliLSWriter) flush(lg *LogGroup) { + + c.lock.Lock() + defer c.lock.Unlock() + + // 把以上的LogGroup推送到SLS服务器, + // SLS服务器会根据该logstore的shard个数自动进行负载均衡。 + err := c.store.PutLogs(lg) + if err != nil { + return + } + + lg.Logs = make([]*Log, 0, c.FlushWhen) +} + +func init() { + logs.Register(logs.AdapterAliLS, NewAliLS) +} diff --git a/logs/alils/config.go b/logs/alils/config.go new file mode 100755 index 00000000..e8c24448 --- /dev/null +++ b/logs/alils/config.go @@ -0,0 +1,13 @@ +package alils + +const ( + version = "0.5.0" // SDK version + signatureMethod = "hmac-sha1" // Signature method + + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the shard. + OffsetNewest = "end" + // OffsetOldest stands for the oldest offset available on the logstore for a + // shard. + OffsetOldest = "begin" +) diff --git a/logs/alils/log.pb.go b/logs/alils/log.pb.go new file mode 100755 index 00000000..b60ca200 --- /dev/null +++ b/logs/alils/log.pb.go @@ -0,0 +1,984 @@ +package alils + +import "github.com/gogo/protobuf/proto" +import "fmt" +import "math" + +// discarding unused import gogoproto "." + +import github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" + +import "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +type Log struct { + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*Log_Content `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Log) Reset() { *m = Log{} } +func (m *Log) String() string { return proto.CompactTextString(m) } +func (*Log) ProtoMessage() {} + +func (m *Log) GetTime() uint32 { + if m != nil && m.Time != nil { + return *m.Time + } + return 0 +} + +func (m *Log) GetContents() []*Log_Content { + if m != nil { + return m.Contents + } + return nil +} + +type Log_Content struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Log_Content) Reset() { *m = Log_Content{} } +func (m *Log_Content) String() string { return proto.CompactTextString(m) } +func (*Log_Content) ProtoMessage() {} + +func (m *Log_Content) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +func (m *Log_Content) GetValue() string { + if m != nil && m.Value != nil { + return *m.Value + } + return "" +} + +type LogGroup struct { + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *LogGroup) Reset() { *m = LogGroup{} } +func (m *LogGroup) String() string { return proto.CompactTextString(m) } +func (*LogGroup) ProtoMessage() {} + +func (m *LogGroup) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +func (m *LogGroup) GetReserved() string { + if m != nil && m.Reserved != nil { + return *m.Reserved + } + return "" +} + +func (m *LogGroup) GetTopic() string { + if m != nil && m.Topic != nil { + return *m.Topic + } + return "" +} + +func (m *LogGroup) GetSource() string { + if m != nil && m.Source != nil { + return *m.Source + } + return "" +} + +type LogGroupList struct { + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *LogGroupList) Reset() { *m = LogGroupList{} } +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } +func (*LogGroupList) ProtoMessage() {} + +func (m *LogGroupList) GetLogGroups() []*LogGroup { + if m != nil { + return m.LogGroups + } + return nil +} + +func (m *Log) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *Log) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Time == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } else { + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, msg := range m.Contents { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *Log_Content) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *Log_Content) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Key == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } else { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + } + if m.Value == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } else { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *LogGroup) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *LogGroup) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Logs) > 0 { + for _, msg := range m.Logs { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.Reserved != nil { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) + i += copy(data[i:], *m.Reserved) + } + if m.Topic != nil { + data[i] = 0x1a + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Topic))) + i += copy(data[i:], *m.Topic) + } + if m.Source != nil { + data[i] = 0x22 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Source))) + i += copy(data[i:], *m.Source) + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func (m *LogGroupList) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *LogGroupList) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, msg := range m.LogGroups { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXX_unrecognized != nil { + i += copy(data[i:], m.XXX_unrecognized) + } + return i, nil +} + +func encodeFixed64Log(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Log(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintLog(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} +func (m *Log) Size() (n int) { + var l int + _ = l + if m.Time != nil { + n += 1 + sovLog(uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, e := range m.Contents { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *Log_Content) Size() (n int) { + var l int + _ = l + if m.Key != nil { + l = len(*m.Key) + n += 1 + l + sovLog(uint64(l)) + } + if m.Value != nil { + l = len(*m.Value) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *LogGroup) Size() (n int) { + var l int + _ = l + if len(m.Logs) > 0 { + for _, e := range m.Logs { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.Reserved != nil { + l = len(*m.Reserved) + n += 1 + l + sovLog(uint64(l)) + } + if m.Topic != nil { + l = len(*m.Topic) + n += 1 + l + sovLog(uint64(l)) + } + if m.Source != nil { + l = len(*m.Source) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *LogGroupList) Size() (n int) { + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, e := range m.LogGroups { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovLog(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozLog(x uint64) (n int) { + return sovLog((x << 1) ^ (x >> 63)) +} +func (m *Log) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Log: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) + } + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + v |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Time = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Contents = append(m.Contents, &Log_Content{}) + if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Log_Content) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Content: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Key = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Value = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *LogGroup) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Logs = append(m.Logs, &Log{}) + if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Reserved = &s + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Topic = &s + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Source = &s + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *LogGroupList) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.LogGroups = append(m.LogGroups, &LogGroup{}) + if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipLog(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthLog + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipLog(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) diff --git a/logs/alils/log_config.go b/logs/alils/log_config.go new file mode 100755 index 00000000..41fa0959 --- /dev/null +++ b/logs/alils/log_config.go @@ -0,0 +1,39 @@ +package alils + +type InputDetail struct { + LogType string `json:"logType"` + LogPath string `json:"logPath"` + FilePattern string `json:"filePattern"` + LocalStorage bool `json:"localStorage"` + TimeFormat string `json:"timeFormat"` + LogBeginRegex string `json:"logBeginRegex"` + Regex string `json:"regex"` + Keys []string `json:"key"` + FilterKeys []string `json:"filterKey"` + FilterRegex []string `json:"filterRegex"` + TopicFormat string `json:"topicFormat"` +} + +type OutputDetail struct { + Endpoint string `json:"endpoint"` + LogStoreName string `json:"logstoreName"` +} + +type LogConfig struct { + Name string `json:"configName"` + InputType string `json:"inputType"` + InputDetail InputDetail `json:"inputDetail"` + OutputType string `json:"outputType"` + OutputDetail OutputDetail `json:"outputDetail"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// GetAppliedMachineGroup returns applied machine group of this config. +func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { + groupNames, err = c.project.GetAppliedMachineGroups(c.Name) + return +} diff --git a/logs/alils/log_project.go b/logs/alils/log_project.go new file mode 100755 index 00000000..63ab07f8 --- /dev/null +++ b/logs/alils/log_project.go @@ -0,0 +1,818 @@ +/* +Package sls implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). + +For more description about SLS, please read this article: +http://gitlab.alibaba-inc.com/sls/doc. +*/ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// Error message in SLS HTTP response. +type errorMessage struct { + Code string `json:"errorCode"` + Message string `json:"errorMessage"` +} + +type LogProject struct { + Name string // Project name + Endpoint string // IP or hostname of SLS endpoint + AccessKeyId string + AccessKeySecret string +} + +// NewLogProject creates a new SLS project. +func NewLogProject(name, endpoint, accessKeyId, accessKeySecret string) (p *LogProject, err error) { + p = &LogProject{ + Name: name, + Endpoint: endpoint, + AccessKeyId: accessKeyId, + AccessKeySecret: accessKeySecret, + } + return p, nil +} + +// ListLogStore returns all logstore names of project p. +func (p *LogProject) ListLogStore() (storeNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores") + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + LogStores []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + storeNames = body.LogStores + + return +} + +// GetLogStore returns logstore according by logstore name. +func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/logstores/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + s = &LogStore{} + err = json.Unmarshal(buf, s) + if err != nil { + return + } + s.project = p + return +} + +// CreateLogStore creates a new logstore in SLS, +// where name is logstore name, +// and ttl is time-to-live(in day) of logs, +// and shardCnt is the number of shards. +func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteLogStore deletes a logstore according by logstore name. +func (p *LogProject) DeleteLogStore(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/logstores/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// UpdateLogStore updates a logstore according by logstore name, +// obviously we can't modify the logstore name itself. +func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// ListMachineGroup returns machine group name list and the total number of machine groups. +// The offset starts from 0 and the size is the max number of machine groups could be returned. +func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 500 + } + + uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + MachineGroups []string + Count int + Total int + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + m = body.MachineGroups + total = body.Total + + return +} + +// GetMachineGroup retruns machine group according by machine group name. +func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get machine group:%v", name) + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + m = &MachineGroup{} + err = json.Unmarshal(buf, m) + if err != nil { + return + } + m.project = p + return +} + +// CreateMachineGroup creates a new machine group in SLS. +func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/machinegroups", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// UpdateMachineGroup updates a machine group. +func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteMachineGroup deletes machine group according machine group name. +func (p *LogProject) DeleteMachineGroup(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// ListConfig returns config names list and the total number of configs. +// The offset starts from 0 and the size is the max number of configs could be returned. +func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 100 + } + + uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Total int + Configs []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + cfgNames = body.Configs + total = body.Total + return +} + +// GetConfig returns config according by config name. +func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/configs/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + c = &LogConfig{} + err = json.Unmarshal(buf, c) + if err != nil { + return + } + c.project = p + return +} + +// UpdateConfig updates a config. +func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/configs/"+c.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// CreateConfig creates a new config in SLS. +func (p *LogProject) CreateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/configs", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteConfig deletes a config according by config name. +func (p *LogProject) DeleteConfig(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/configs/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetAppliedMachineGroups returns applied machine group names list according config name. +func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/configs/%v/machinegroups", confName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get applied machine groups") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + Machinegroups []string + } + + body := &Body{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + groupNames = body.Machinegroups + return +} + +// GetAppliedConfigs returns applied config names list according machine group name groupName. +func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to applied configs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Cfg struct { + Count int `json:"count"` + Configs []string `json:"configs"` + } + + body := &Cfg{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + confNames = body.Configs + return +} + +// ApplyConfigToMachineGroup applies config to machine group. +func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "PUT", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to apply config to machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// RemoveConfigFromMachineGroup removes config from machine group. +func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "DELETE", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} diff --git a/logs/alils/log_store.go b/logs/alils/log_store.go new file mode 100755 index 00000000..009e39c4 --- /dev/null +++ b/logs/alils/log_store.go @@ -0,0 +1,269 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strconv" + + lz4 "github.com/cloudflare/golz4" + "github.com/gogo/protobuf/proto" +) + +type LogStore struct { + Name string `json:"logstoreName"` + TTL int + ShardCount int + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +type Shard struct { + ShardID int `json:"shardID"` +} + +// ListShards returns shard id list of this logstore. +func (s *LogStore) ListShards() (shardIDs []int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards", s.Name) + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + var shards []*Shard + err = json.Unmarshal(buf, &shards) + if err != nil { + return + } + + for _, v := range shards { + shardIDs = append(shardIDs, v.ShardID) + } + return +} + +// PutLogs put logs into logstore. +// The callers should transform user logs into LogGroup. +func (s *LogStore) PutLogs(lg *LogGroup) (err error) { + body, err := proto.Marshal(lg) + if err != nil { + return + } + + // Compresse body with lz4 + out := make([]byte, lz4.CompressBound(body)) + n, err := lz4.Compress(body, out) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-compresstype": "lz4", + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/x-protobuf", + } + + uri := fmt.Sprintf("/logstores/%v", s.Name) + r, err := request(s.project, "POST", uri, h, out[:n]) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to put logs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetCursor gets log cursor of one shard specified by shardId. +// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". +// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore +func (s *LogStore) GetCursor(shardId int, from string) (cursor string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", + s.Name, shardId, from) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Cursor string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + cursor = body.Cursor + return +} + +// GetLogsBytes gets logs binary data from shard specified by shardId according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogsBytes(shardId int, cursor string, + logGroupMaxCount int) (out []byte, nextCursor string, err error) { + + h := map[string]string{ + "x-sls-bodyrawsize": "0", + "Accept": "application/x-protobuf", + "Accept-Encoding": "lz4", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", + s.Name, shardId, cursor, logGroupMaxCount) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + v, ok := r.Header["X-Sls-Compresstype"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-compresstype' header") + return + } + if v[0] != "lz4" { + err = fmt.Errorf("unexpected compress type:%v", v[0]) + return + } + + v, ok = r.Header["X-Sls-Cursor"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-cursor' header") + return + } + nextCursor = v[0] + + v, ok = r.Header["X-Sls-Bodyrawsize"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") + return + } + bodyRawSize, err := strconv.Atoi(v[0]) + if err != nil { + return + } + + out = make([]byte, bodyRawSize) + err = lz4.Uncompress(buf, out) + if err != nil { + return + } + + return +} + +// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { + + gl = &LogGroupList{} + err = proto.Unmarshal(data, gl) + if err != nil { + return + } + + return +} + +// GetLogs gets logs from shard specified by shardId according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogs(shardId int, cursor string, + logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { + + out, nextCursor, err := s.GetLogsBytes(shardId, cursor, logGroupMaxCount) + if err != nil { + return + } + + gl, err = LogsBytesDecode(out) + if err != nil { + return + } + + return +} diff --git a/logs/alils/machine_group.go b/logs/alils/machine_group.go new file mode 100755 index 00000000..7a0aace1 --- /dev/null +++ b/logs/alils/machine_group.go @@ -0,0 +1,87 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +type MachinGroupAttribute struct { + ExternalName string `json:"externalName"` + TopicName string `json:"groupTopic"` +} + +type MachineGroup struct { + Name string `json:"groupName"` + Type string `json:"groupType"` + MachineIdType string `json:"machineIdentifyType"` + MachineIdList []string `json:"machineList"` + + Attribute MachinGroupAttribute `json:"groupAttribute"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +type Machine struct { + IP string + UniqueId string `json:"machine-uniqueid"` + UserdefinedId string `json:"userdefined-id"` +} + +type MachineList struct { + Total int + Machines []*Machine +} + +// ListMachines returns machine list of this machine group. +func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) + r, err := request(m.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + body := &MachineList{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + ms = body.Machines + total = body.Total + + return +} + +// GetAppliedConfigs returns applied configs of this machine group. +func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { + confNames, err = m.project.GetAppliedConfigs(m.Name) + return +} diff --git a/logs/alils/request.go b/logs/alils/request.go new file mode 100755 index 00000000..20df45b4 --- /dev/null +++ b/logs/alils/request.go @@ -0,0 +1,62 @@ +package alils + +import ( + "bytes" + "crypto/md5" + "fmt" + "net/http" +) + +// request sends a request to SLS. +func request(project *LogProject, method, uri string, headers map[string]string, + body []byte) (resp *http.Response, err error) { + + // The caller should provide 'x-sls-bodyrawsize' header + if _, ok := headers["x-sls-bodyrawsize"]; !ok { + err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + return + } + + // SLS public request headers + headers["Host"] = project.Name + "." + project.Endpoint + headers["Date"] = nowRFC1123() + headers["x-sls-apiversion"] = version + headers["x-sls-signaturemethod"] = signatureMethod + if body != nil { + bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) + headers["Content-MD5"] = bodyMD5 + + if _, ok := headers["Content-Type"]; !ok { + err = fmt.Errorf("Can't find 'Content-Type' header") + return + } + } + + // Calc Authorization + // Authorization = "SLS :" + digest, err := signature(project, method, uri, headers) + if err != nil { + return + } + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyId, digest) + headers["Authorization"] = auth + + // Initialize http request + reader := bytes.NewReader(body) + urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) + req, err := http.NewRequest(method, urlStr, reader) + if err != nil { + return + } + for k, v := range headers { + req.Header.Add(k, v) + } + + // Get ready to do request + resp, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + return +} diff --git a/logs/alils/signature.go b/logs/alils/signature.go new file mode 100755 index 00000000..2d611307 --- /dev/null +++ b/logs/alils/signature.go @@ -0,0 +1,111 @@ +package alils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net/url" + "sort" + "strings" + "time" +) + +// GMT location +var gmtLoc = time.FixedZone("GMT", 0) + +// NowRFC1123 returns now time in RFC1123 format with GMT timezone, +// eg. "Mon, 02 Jan 2006 15:04:05 GMT". +func nowRFC1123() string { + return time.Now().In(gmtLoc).Format(time.RFC1123) +} + +// signature calculates a request's signature digest. +func signature(project *LogProject, method, uri string, + headers map[string]string) (digest string, err error) { + var contentMD5, contentType, date, canoHeaders, canoResource string + var slsHeaderKeys sort.StringSlice + + // SignString = VERB + "\n" + // + CONTENT-MD5 + "\n" + // + CONTENT-TYPE + "\n" + // + DATE + "\n" + // + CanonicalizedSLSHeaders + "\n" + // + CanonicalizedResource + + if val, ok := headers["Content-MD5"]; ok { + contentMD5 = val + } + + if val, ok := headers["Content-Type"]; ok { + contentType = val + } + + date, ok := headers["Date"] + if !ok { + err = fmt.Errorf("Can't find 'Date' header") + return + } + + // Calc CanonicalizedSLSHeaders + slsHeaders := make(map[string]string, len(headers)) + for k, v := range headers { + l := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(l, "x-sls-") { + slsHeaders[l] = strings.TrimSpace(v) + slsHeaderKeys = append(slsHeaderKeys, l) + } + } + + sort.Sort(slsHeaderKeys) + for i, k := range slsHeaderKeys { + canoHeaders += k + ":" + slsHeaders[k] + if i+1 < len(slsHeaderKeys) { + canoHeaders += "\n" + } + } + + // Calc CanonicalizedResource + u, err := url.Parse(uri) + if err != nil { + return + } + + canoResource += url.QueryEscape(u.Path) + if u.RawQuery != "" { + var keys sort.StringSlice + + vals := u.Query() + for k := range vals { + keys = append(keys, k) + } + + sort.Sort(keys) + canoResource += "?" + for i, k := range keys { + if i > 0 { + canoResource += "&" + } + + for _, v := range vals[k] { + canoResource += k + "=" + v + } + } + } + + signStr := method + "\n" + + contentMD5 + "\n" + + contentType + "\n" + + date + "\n" + + canoHeaders + "\n" + + canoResource + + // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) + mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) + _, err = mac.Write([]byte(signStr)) + if err != nil { + return + } + digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return +} diff --git a/logs/console.go b/logs/console.go index e6bf6c29..e75f2a1b 100644 --- a/logs/console.go +++ b/logs/console.go @@ -41,7 +41,7 @@ var colors = []brush{ newBrush("1;33"), // Warning yellow newBrush("1;32"), // Notice green newBrush("1;34"), // Informational blue - newBrush("1;34"), // Debug blue + newBrush("1;44"), // Debug Background blue } // consoleWriter implements LoggerInterface and writes messages to terminal. diff --git a/logs/file.go b/logs/file.go index 42146dae..32595ec4 100644 --- a/logs/file.go +++ b/logs/file.go @@ -193,16 +193,14 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { y, m, d := openTime.Add(24 * time.Hour).Date() nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) - select { - case <-tm.C: - w.Lock() - if w.needRotate(0, time.Now().Day()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } + <-tm.C + w.Lock() + if w.needRotate(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } - w.Unlock() } + w.Unlock() } func (w *fileLogWriter) lines() (int, error) { @@ -270,6 +268,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { // Rename the file to its new found name // even if occurs error,we MUST guarantee to restart new logger err = os.Rename(w.Filename, fName) + err = os.Chmod(fName, os.FileMode(0440)) // re-start logger RESTART_LOGGER: diff --git a/logs/file_test.go b/logs/file_test.go index 69a66d84..f345ff20 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -162,7 +162,27 @@ func TestFileRotate_05(t *testing.T) { testFileDailyRotate(t, fn1, fn2) os.Remove(fn) } - +func TestFileRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} func testFileRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ Daily: true, diff --git a/logs/jianliao.go b/logs/jianliao.go index 3755118d..16bbdc2e 100644 --- a/logs/jianliao.go +++ b/logs/jianliao.go @@ -25,11 +25,7 @@ func newJLWriter() Logger { // Init JLWriter with json config string func (s *JLWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } // WriteMsg write message in smtp writer. @@ -56,10 +52,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) } - resp.Body.Close() return nil } diff --git a/logs/log.go b/logs/log.go index 3d512d2e..a5e0a9ea 100644 --- a/logs/log.go +++ b/logs/log.go @@ -71,6 +71,7 @@ const ( AdapterEs = "es" AdapterJianLiao = "jianliao" AdapterSlack = "slack" + AdapterAliLS = "alils" ) // Legacy log level constants to ensure backwards compatibility. @@ -274,7 +275,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error line = 0 } _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg + msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg } //set level info in front of filename info @@ -380,7 +381,10 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { // Warning Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { - bl.Warn(format, v...) + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) } // Notice Log NOTICE level message. @@ -393,7 +397,10 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { // Informational Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { - bl.Info(format, v...) + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) } // Debug Log DEBUG level message. @@ -425,7 +432,10 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { // Trace Log TRACE level message. // compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { - bl.Debug(format, v...) + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) } // Flush flush all chan data. @@ -551,11 +561,7 @@ func SetLogFuncCallDepth(d int) { // SetLogger sets a new logger. func SetLogger(adapter string, config ...string) error { - err := beeLogger.SetLogger(adapter, config...) - if err != nil { - return err - } - return nil + return beeLogger.SetLogger(adapter, config...) } // Emergency logs a message at emergency level. diff --git a/logs/slack.go b/logs/slack.go index eddedd5d..d8482b3e 100644 --- a/logs/slack.go +++ b/logs/slack.go @@ -21,11 +21,7 @@ func newSLACKWriter() Logger { // Init SLACKWriter with json config string func (s *SLACKWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } // WriteMsg write message in smtp writer. @@ -44,10 +40,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) } - resp.Body.Close() return nil } diff --git a/logs/smtp.go b/logs/smtp.go index 834130ef..2ec0514d 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -52,11 +52,7 @@ func newSMTPWriter() Logger { // "level":LevelError // } func (s *SMTPWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), s) - if err != nil { - return err - } - return nil + return json.Unmarshal([]byte(jsonconfig), s) } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -106,7 +102,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd if err != nil { return err } - _, err = w.Write([]byte(msgContent)) + _, err = w.Write(msgContent) if err != nil { return err } @@ -116,12 +112,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return err } - err = client.Quit() - if err != nil { - return err - } - - return nil + return client.Quit() } // WriteMsg write message in smtp writer. diff --git a/namespace_test.go b/namespace_test.go index fc02b5fb..b3f20dff 100644 --- a/namespace_test.go +++ b/namespace_test.go @@ -139,10 +139,7 @@ func TestNamespaceCond(t *testing.T) { ns := NewNamespace("/v2") ns.Cond(func(ctx *context.Context) bool { - if ctx.Input.Domain() == "beego.me" { - return true - } - return false + return ctx.Input.Domain() == "beego.me" }). AutoRouter(&TestController{}) AddNamespace(ns) diff --git a/orm/cmd.go b/orm/cmd.go index 3638a75c..0ff4dc40 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -150,7 +150,7 @@ func (d *commandSyncDb) Run() error { } for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; ok == false { + if _, ok := columns[fi.column]; !ok { fields = append(fields, fi) } } @@ -175,7 +175,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false { + if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 8119b70b..de47cb02 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -89,7 +89,7 @@ checkColumn: col = T["float64"] case TypeDecimalField: s := T["float64-decimal"] - if strings.Index(s, "%d") == -1 { + if !strings.Contains(s, "%d") { col = s } else { col = fmt.Sprintf(s, fi.digits, fi.decimals) @@ -120,7 +120,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) - if fi.null == false { + if !fi.null { typ += " " + "NOT NULL" } @@ -172,7 +172,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex } else { column += col - if fi.null == false { + if !fi.null { column += " " + "NOT NULL" } @@ -192,7 +192,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex } } - if strings.Index(column, "%COL%") != -1 { + if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } diff --git a/orm/db.go b/orm/db.go index 30d8ae4e..1544c479 100644 --- a/orm/db.go +++ b/orm/db.go @@ -48,6 +48,7 @@ var ( "lte": true, "eq": true, "nq": true, + "ne": true, "startswith": true, "endswith": true, "istartswith": true, @@ -86,7 +87,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } else { panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) } - if fi.dbcol == false || fi.auto && skipAuto { + if !fi.dbcol || fi.auto && skipAuto { continue } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) @@ -223,7 +224,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val value = nil } } - if fi.null == false && value == nil { + if !fi.null && value == nil { return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) } } @@ -270,7 +271,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, dbcols := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.fields.dbcols)) for _, fi := range mi.fields.fieldsDB { - if fi.auto == false { + if !fi.auto { dbcols = append(dbcols, fi.column) marks = append(marks, "?") } @@ -325,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } else { // default use pk value as where condtion. pkColumn, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return ErrMissPK } whereCols = []string{pkColumn} @@ -600,7 +601,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a // execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return 0, ErrMissPK } @@ -653,7 +654,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. } else { // default use pk value as where condtion. pkColumn, pkValue, ok := getExistPk(mi, ind) - if ok == false { + if !ok { return 0, ErrMissPK } whereCols = []string{pkColumn} @@ -698,7 +699,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false { + if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { panic(fmt.Errorf("wrong field/column name `%s`", col)) } else { columns = append(columns, fi.column) @@ -928,7 +929,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if hasRel { for _, fi := range mi.fields.fieldsDB { if fi.fieldType&IsRelField > 0 { - if maps[fi.column] == false { + if !maps[fi.column] { tCols = append(tCols, fi.column) } } @@ -986,7 +987,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var cnt int64 for rs.Next() { - if one && cnt == 0 || one == false { + if one && cnt == 0 || !one { if err := rs.Scan(refs...); err != nil { return 0, err } @@ -1066,7 +1067,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cnt++ } - if one == false { + if !one { if cnt > 0 { ind.Set(slice) } else { @@ -1356,7 +1357,7 @@ end: func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { fieldType := fi.fieldType - isNative := fi.isFielder == false + isNative := !fi.isFielder setValue: switch { @@ -1532,7 +1533,7 @@ setValue: } } - if isNative == false { + if !isNative { fd := field.Addr().Interface().(Fielder) err := fd.SetRaw(value) if err != nil { @@ -1593,7 +1594,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond infos = make([]*fieldInfo, 0, len(exprs)) for _, ex := range exprs { index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) diff --git a/orm/db_alias.go b/orm/db_alias.go index c95d49c9..8daa4b23 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -186,7 +186,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) } - if dataBaseCache.add(aliasName, al) == false { + if !dataBaseCache.add(aliasName, al) { return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } @@ -244,7 +244,7 @@ end: // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. func RegisterDriver(driverName string, typ DriverType) error { - if t, ok := drivers[driverName]; ok == false { + if t, ok := drivers[driverName]; !ok { drivers[driverName] = typ } else { if t != typ { diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 10fe2657..1016de2b 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -16,6 +16,8 @@ package orm import ( "fmt" + "reflect" + "strings" ) // mysql operators. @@ -96,6 +98,83 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +// Add "`" for mysql sql building +func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + + iouStr := "" + argsMap := map[string]string{} + + iouStr = "ON DUPLICATE KEY UPDATE" + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if valueStr != "" { + updates[i] = "`" + v + "`" + "=" + valueStr + } else { + updates[i] = "`" + v + "`" + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + return id, err +} + // create new mysql dbBaser. func newdbBaseMysql() dbBaser { b := new(dbBaseMysql) diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index a3cb69a7..a43a5594 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -134,7 +134,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool defer rows.Close() for rows.Next() { var tmp, index sql.NullString - rows.Scan(&tmp, &index, &tmp) + rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) if name == index.String { return true } diff --git a/orm/db_tables.go b/orm/db_tables.go index e4c74ace..42be5550 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -63,7 +63,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) // add table info to collection. func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) - if _, ok := t.tablesM[name]; ok == false { + if _, ok := t.tablesM[name]; !ok { i := len(t.tables) + 1 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} t.tablesM[name] = jt @@ -261,7 +261,7 @@ loopFor: fiN, okN = mmi.fields.GetByAny(exprs[i+1]) } - if isRel && (fi.mi.isThrough == false || num != i) { + if isRel && (!fi.mi.isThrough || num != i) { if fi.null || t.skipEnd { inner = false } @@ -364,7 +364,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe } index, _, fi, suc := t.parseExprs(mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) } @@ -383,7 +383,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe } } - if sub == false && where != "" { + if !sub && where != "" { where = "WHERE " + where } @@ -403,7 +403,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { exprs := strings.Split(group, ExprSep) index, _, fi, suc := t.parseExprs(t.mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } @@ -432,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { exprs := strings.Split(order, ExprSep) index, _, fi, suc := t.parseExprs(t.mi, exprs) - if suc == false { + if !suc { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } diff --git a/orm/db_utils.go b/orm/db_utils.go index 0279a14a..7ae10ca5 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac vu := v.Int() exist = true value = vu + } else if fi.fieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) } else { vu := v.String() exist = vu != "" @@ -147,8 +149,10 @@ outFor: arg = v.In(tz).Format(formatDate) } else if fi != nil && fi.fieldType == TypeDateTimeField { arg = v.In(tz).Format(formatDateTime) - } else { + } else if fi != nil && fi.fieldType == TypeTimeField { arg = v.In(tz).Format(formatTime) + } else { + arg = v.In(tz).Format(formatDateTime) } } else { typ := val.Type() diff --git a/orm/models_boot.go b/orm/models_boot.go index 319785ce..85d0917f 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -117,7 +117,7 @@ func bootStrap() { name := getFullName(elm) mii, ok := modelCache.getByFullName(name) if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) goto end } fi.relModelInfo = mii @@ -128,7 +128,7 @@ func bootStrap() { if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { pn := fi.relThrough[:i] rmi, ok := modelCache.getByFullName(fi.relThrough) - if ok == false || pn != rmi.pkg { + if !ok || pn != rmi.pkg { err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) goto end } @@ -171,7 +171,7 @@ func bootStrap() { break } } - if inModel == false { + if !inModel { rmi := fi.relModelInfo ffi := new(fieldInfo) ffi.name = mi.name @@ -185,7 +185,7 @@ func bootStrap() { } else { ffi.fieldType = RelReverseMany } - if rmi.fields.Add(ffi) == false { + if !rmi.fields.Add(ffi) { added := false for cnt := 0; cnt < 5; cnt++ { ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) @@ -195,7 +195,7 @@ func bootStrap() { break } } - if added == false { + if !added { panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) } } @@ -248,7 +248,7 @@ func bootStrap() { break mForA } } - if found == false { + if !found { err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) goto end } @@ -267,7 +267,7 @@ func bootStrap() { break mForB } } - if found == false { + if !found { mForC: for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || @@ -287,7 +287,7 @@ func bootStrap() { } } } - if found == false { + if !found { err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) goto end } diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 4b3d3e27..bbb7d71f 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -47,7 +47,7 @@ func (f *fields) Add(fi *fieldInfo) (added bool) { } else { return } - if _, ok := f.fieldsByType[fi.fieldType]; ok == false { + if _, ok := f.fieldsByType[fi.fieldType]; !ok { f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) } f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) @@ -334,12 +334,12 @@ checkType: switch onDelete { case odCascade, odDoNothing: case odSetDefault: - if initial.Exist() == false { + if !initial.Exist() { err = errors.New("on_delete: set_default need set field a default value") goto end } case odSetNULL: - if fi.null == false { + if !fi.null { err = errors.New("on_delete: set_null need set field null") goto end } diff --git a/orm/models_info_m.go b/orm/models_info_m.go index d6ba1dca..4a3a37f9 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -78,7 +78,7 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) fi.fieldIndex = append(index, i) fi.mi = mi fi.inModel = true - if mi.fields.Add(fi) == false { + if !mi.fields.Add(fi) { err = fmt.Errorf("duplicate column name: %s", fi.column) break } diff --git a/orm/models_test.go b/orm/models_test.go index 462370b2..9843a87d 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -406,6 +406,11 @@ type UintPk struct { Name string } +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + var DBARGS = struct { Driver string Source string diff --git a/orm/orm.go b/orm/orm.go index 538916e4..d364e621 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -122,21 +122,13 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { // read data to model func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) - if err != nil { - return err - } - return nil + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) - if err != nil { - return err - } - return nil + return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -153,6 +145,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { id = int64(vid.Uint()) + } else if mi.fields.pk.rel { + return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) } else { id = vid.Int() } @@ -236,15 +230,11 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64 // cols set the columns those want to update. func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - return num, nil + return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) } // delete model in database -// cols shows the delete conditions values read from. deafult is pk +// cols shows the delete conditions values read from. default is pk func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) @@ -359,7 +349,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) - if exist == false { + if !exist { panic(ErrMissPK) } @@ -487,7 +477,7 @@ func (o *orm) Begin() error { // commit transaction func (o *orm) Commit() error { - if o.isTx == false { + if !o.isTx { return ErrTxDone } err := o.db.(txEnder).Commit() @@ -502,7 +492,7 @@ func (o *orm) Commit() error { // rollback transaction func (o *orm) Rollback() error { - if o.isTx == false { + if !o.isTx { return ErrTxDone } err := o.db.(txEnder).Rollback() diff --git a/orm/orm_conds.go b/orm/orm_conds.go index e56d6fbb..f6e389ec 100644 --- a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -75,6 +75,19 @@ func (c *Condition) AndCond(cond *Condition) *Condition { return c } +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) + } + return c +} + // Or add OR expression to condition func (c Condition) Or(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { @@ -105,6 +118,19 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) + } + return c +} + // IsEmpty check the condition arguments are empty or not. func (c *Condition) IsEmpty() bool { return len(c.params) == 0 diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index b220bda6..6a270a0d 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -72,7 +72,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } _, v1, exist := getExistPk(o.mi, o.ind) - if exist == false { + if !exist { panic(ErrMissPK) } @@ -87,7 +87,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { v2 = ind.Interface() } else { _, v2, exist = getExistPk(fi.relModelInfo, ind) - if exist == false { + if !exist { panic(ErrMissPK) } } @@ -104,11 +104,7 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) - nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() - if err != nil { - return nums, err - } - return nums, nil + return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() } // check model is existed in relationship of origin model diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 575f62ae..4e33646d 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter { return &o } +// get condition from QuerySeter +func (o querySet) GetCond() *Condition { + return o.cond +} + // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) diff --git a/orm/orm_test.go b/orm/orm_test.go index fbf4768d..a8b4fe69 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -93,14 +93,14 @@ wrongArg: } func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); ok == false { + if ok, err := ValuesCompare(true, a, args...); !ok { return err } return nil } func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); ok == false { + if ok, err := ValuesCompare(false, a, args...); !ok { return err } return nil @@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(InLineOneToOne)) RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(InLineOneToOne)) RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) BootStrap() @@ -909,6 +911,16 @@ func TestSetCond(t *testing.T) { num, err = qs.SetCond(cond2).Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) } func TestLimit(t *testing.T) { @@ -2134,6 +2146,48 @@ func TestUintPk(t *testing.T) { dORM.Delete(u) } +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + func TestSnake(t *testing.T) { cases := map[string]string{ "i": "i", diff --git a/orm/types.go b/orm/types.go index fd3062ab..3e6a9e87 100644 --- a/orm/types.go +++ b/orm/types.go @@ -145,6 +145,16 @@ type QuerySeter interface { // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // num, err := qs.SetCond(cond1).Count() SetCond(*Condition) QuerySeter + // get condition from QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) + // qs = qs.SetCond(cond) + // cond = qs.GetCond() + // cond := cond.Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond).Count() + GetCond() *Condition // add LIMIT value. // args[0] means offset, e.g. LIMIT num,offset. // if Limit <= 0 then Limit will be set to default limit ,eg 1000 diff --git a/orm/utils.go b/orm/utils.go index 6e23447e..669d4734 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -16,6 +16,7 @@ package orm import ( "fmt" + "math/big" "reflect" "strconv" "strings" @@ -87,7 +88,15 @@ func (f StrTo) Int32() (int32, error) { // Int64 string to int64 func (f StrTo) Int64() (int64, error) { v, err := strconv.ParseInt(f.String(), 10, 64) - return int64(v), err + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err } // Uint string to uint @@ -117,7 +126,15 @@ func (f StrTo) Uint32() (uint32, error) { // Uint64 string to uint64 func (f StrTo) Uint64() (uint64, error) { v, err := strconv.ParseUint(f.String(), 10, 64) - return uint64(v), err + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err } // String string to string @@ -202,22 +219,17 @@ func snakeString(s string) string { // camel string, xx_yy to XxYy func camelString(s string) string { data := make([]byte, 0, len(s)) - j := false - k := false - num := len(s) - 1 + flag, num := true, len(s)-1 for i := 0; i <= num; i++ { d := s[i] - if k == false && d >= 'A' && d <= 'Z' { - k = true - } - if d >= 'a' && d <= 'z' && (j || k == false) { - d = d - 32 - j = false - k = true - } - if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { - j = true + if d == '_' { + flag = true continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false } data = append(data, d) } diff --git a/orm/utils_test.go b/orm/utils_test.go new file mode 100644 index 00000000..8c7c5008 --- /dev/null +++ b/orm/utils_test.go @@ -0,0 +1,36 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go index 10636d1c..f816029c 100644 --- a/plugins/apiauth/apiauth.go +++ b/plugins/apiauth/apiauth.go @@ -56,6 +56,7 @@ package apiauth import ( + "bytes" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { // Signature used to generate signature with the appsecret/method/params/RequestURI func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { - var query string + var b bytes.Buffer + keys := make([]string, len(params)) pa := make(map[string]string) for k, v := range params { pa[k] = v[0] + keys = append(keys, k) } - vs := mapSorter(pa) - vs.Sort() - for i := 0; i < vs.Len(); i++ { - if vs.Keys[i] == "signature" { + + sort.Strings(keys) + + for _, key := range keys { + if key == "signature" { continue } - if vs.Keys[i] != "" && vs.Vals[i] != "" { - query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i]) + + val := pa[key] + if key != "" && val != "" { + b.WriteString(key) + b.WriteString(val) } } - stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURL) + + stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) sha256 := sha256.New hash := hmac.New(sha256, []byte(appsecret)) hash.Write([]byte(stringToSign)) return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } - -type valSorter struct { - Keys []string - Vals []string -} - -func mapSorter(m map[string]string) *valSorter { - vs := &valSorter{ - Keys: make([]string, 0, len(m)), - Vals: make([]string, 0, len(m)), - } - for k, v := range m { - vs.Keys = append(vs.Keys, k) - vs.Vals = append(vs.Vals, v) - } - return vs -} - -func (vs *valSorter) Sort() { - sort.Sort(vs) -} - -func (vs *valSorter) Len() int { return len(vs.Keys) } -func (vs *valSorter) Less(i, j int) bool { return vs.Keys[i] < vs.Keys[j] } -func (vs *valSorter) Swap(i, j int) { - vs.Vals[i], vs.Vals[j] = vs.Vals[j], vs.Vals[i] - vs.Keys[i], vs.Keys[j] = vs.Keys[j], vs.Keys[i] -} diff --git a/plugins/apiauth/apiauth_test.go b/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/policy.go b/policy.go new file mode 100644 index 00000000..2b91fdcc --- /dev/null +++ b/policy.go @@ -0,0 +1,97 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + var urlPath = cont.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := cont.Input.Method() + isWildcard := false + // Find policy for current method + t, ok := p.policies[httpMethod] + // If not found - find policy for whole controller + if !ok { + t, ok = p.policies["*"] + isWildcard = true + } + if ok { + runObjects := t.Match(urlPath, cont) + if r, ok := runObjects.([]PolicyFunc); ok { + return r + } else if !isWildcard { + // If no policies found and we checked not for "*" method - try to find it + t, ok = p.policies["*"] + if ok { + runObjects = t.Match(urlPath, cont) + if r, ok = runObjects.([]PolicyFunc); ok { + return r + } + } + } + } + return nil +} + +func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { + method = strings.ToUpper(method) + p.enablePolicy = true + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.policies[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.policies[method] = t + } +} + +// Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + BeeApp.Handlers.addToPolicy(method, pattern, policy...) +} + +// Find policies and execute if were found +func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { + if !p.enablePolicy { + return false + } + // Find Policy for method + policyList := p.FindPolicy(cont) + if len(policyList) > 0 { + // Run policies + for _, runPolicy := range policyList { + runPolicy(cont) + if cont.ResponseWriter.Started { + return true + } + } + return false + } + return false +} diff --git a/router.go b/router.go index 97b0edba..9f573f26 100644 --- a/router.go +++ b/router.go @@ -51,15 +51,22 @@ const ( var ( // HTTPMETHOD list the supported http methods. HTTPMETHOD = map[string]string{ - "GET": "GET", - "POST": "POST", - "PUT": "PUT", - "DELETE": "DELETE", - "PATCH": "PATCH", - "OPTIONS": "OPTIONS", - "HEAD": "HEAD", - "TRACE": "TRACE", - "CONNECT": "CONNECT", + "GET": "GET", + "POST": "POST", + "PUT": "PUT", + "DELETE": "DELETE", + "PATCH": "PATCH", + "OPTIONS": "OPTIONS", + "HEAD": "HEAD", + "TRACE": "TRACE", + "CONNECT": "CONNECT", + "MKCOL": "MKCOL", + "COPY": "COPY", + "MOVE": "MOVE", + "PROPFIND": "PROPFIND", + "PROPPATCH": "PROPPATCH", + "LOCK": "LOCK", + "UNLOCK": "UNLOCK", } // these beego.Controller's methods shouldn't reflect to AutoRouter exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", @@ -114,6 +121,8 @@ type controllerInfo struct { // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree + enablePolicy bool + policies map[string]*Tree enableFilter bool filters [FinishRouter + 1][]*FilterRouter pool sync.Pool @@ -122,7 +131,8 @@ type ControllerRegister struct { // NewControllerRegister returns a new ControllerRegister. func NewControllerRegister() *ControllerRegister { cr := &ControllerRegister{ - routers: make(map[string]*Tree), + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), } cr.pool.New = func() interface{} { return beecontext.NewContext() @@ -711,11 +721,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) goto Admin } + //check policies + if p.execPolicy(context, urlPath) { + goto Admin + } + if routerInfo != nil { - if BConfig.RunMode == DEV { - //store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) - } + //store router pattern into context + context.Input.SetData("RouterPattern", routerInfo.pattern) if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { isRunnable = true diff --git a/router_test.go b/router_test.go index 936fd5e8..25804d67 100644 --- a/router_test.go +++ b/router_test.go @@ -502,10 +502,10 @@ func TestFilterBeforeRouter(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "BeforeRouter1") == false { + if !strings.Contains(rw.Body.String(), "BeforeRouter1") { t.Errorf(testName + " BeforeRouter did not run") } - if strings.Contains(rw.Body.String(), "hello") == true { + if strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " BeforeRouter did not return properly") } } @@ -525,13 +525,13 @@ func TestFilterBeforeExec(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "BeforeExec1") == false { + if !strings.Contains(rw.Body.String(), "BeforeExec1") { t.Errorf(testName + " BeforeExec did not run") } - if strings.Contains(rw.Body.String(), "hello") == true { + if strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " BeforeExec did not return properly") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } } @@ -552,16 +552,16 @@ func TestFilterAfterExec(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "AfterExec1") == false { + if !strings.Contains(rw.Body.String(), "AfterExec1") { t.Errorf(testName + " AfterExec did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } - if strings.Contains(rw.Body.String(), "BeforeExec") == true { + if strings.Contains(rw.Body.String(), "BeforeExec") { t.Errorf(testName + " BeforeExec ran in error") } } @@ -583,19 +583,19 @@ func TestFilterFinishRouter(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == true { + if strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "AfterExec1") == true { + if strings.Contains(rw.Body.String(), "AfterExec1") { t.Errorf(testName + " AfterExec ran in error") } - if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + if strings.Contains(rw.Body.String(), "BeforeRouter") { t.Errorf(testName + " BeforeRouter ran in error") } - if strings.Contains(rw.Body.String(), "BeforeExec") == true { + if strings.Contains(rw.Body.String(), "BeforeExec") { t.Errorf(testName + " BeforeExec ran in error") } } @@ -615,14 +615,14 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter1 did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } // not expected in body - if strings.Contains(rw.Body.String(), "FinishRouter2") == true { + if strings.Contains(rw.Body.String(), "FinishRouter2") { t.Errorf(testName + " FinishRouter2 did run") } } @@ -642,13 +642,13 @@ func TestFilterFinishRouterMulti(t *testing.T) { rw, r := testRequest("GET", url) mux.ServeHTTP(rw, r) - if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter1") { t.Errorf(testName + " FinishRouter1 did not run") } - if strings.Contains(rw.Body.String(), "hello") == false { + if !strings.Contains(rw.Body.String(), "hello") { t.Errorf(testName + " handler did not run properly") } - if strings.Contains(rw.Body.String(), "FinishRouter2") == false { + if !strings.Contains(rw.Body.String(), "FinishRouter2") { t.Errorf(testName + " FinishRouter2 did not run properly") } } diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index 68f37b08..4892a8ab 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -125,10 +125,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist check ledis session exist by sid func (lp *Provider) SessionExist(sid string) bool { count, _ := c.Exists([]byte(sid)) - if count == 0 { - return false - } - return true + return !(count == 0) } // SessionRegenerate generate new sid for ledis session @@ -150,7 +147,7 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) if len(kvs) == 0 { kv = make(map[interface{}]interface{}) } else { - kv, err = session.DecodeGob([]byte(kvs)) + kv, err = session.DecodeGob(kvs) if err != nil { return nil, err } diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index f1069bc9..cad02258 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -205,11 +205,7 @@ func (rp *MemProvider) SessionDestroy(sid string) error { } } - err := client.Delete(sid) - if err != nil { - return err - } - return nil + return client.Delete(sid) } func (rp *MemProvider) connectInit() error { diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 838ec669..7bcf959c 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -170,10 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - return false - } - return true + return !(err == sql.ErrNoRows) } // SessionRegenerate generate new sid for mysql session diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index 73f9c13a..f00e0711 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -184,11 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool { row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - - if err == sql.ErrNoRows { - return false - } - return true + return !(err == sql.ErrNoRows) } // SessionRegenerate generate new sid for postgresql session diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index c46fa7cd..14bbd4c1 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -128,7 +128,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize <= 0 { + if err != nil || poolsize < 0 { rp.poolsize = MaxPoolSize } else { rp.poolsize = poolsize diff --git a/session/sess_file.go b/session/sess_file.go index 132f5a00..12cf1f3f 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -15,8 +15,7 @@ package session import ( - "errors" - "io" + "fmt" "io/ioutil" "net/http" "os" @@ -135,6 +134,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { } else { return nil, err } + + defer f.Close() + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) var kv map[interface{}]interface{} b, err := ioutil.ReadAll(f) @@ -149,7 +151,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { return nil, err } } - f.Close() + ss := &FileSessionStore{sid: sid, values: kv} return ss, nil } @@ -161,10 +163,7 @@ func (fp *FileProvider) SessionExist(sid string) bool { defer filepder.lock.Unlock() _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - if err == nil { - return true - } - return false + return err == nil } // SessionDestroy Remove all files in this save path @@ -204,49 +203,58 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() - err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) - if err != nil { - SLogger.Println(err.Error()) - } - err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) - if err != nil { - SLogger.Println(err.Error()) - } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - var newf *os.File + oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := path.Join(oldPath, oldsid) + newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := path.Join(newPath, sid) + + // new sid file is exist + _, err := os.Stat(newSidFile) if err == nil { - return nil, errors.New("newsid exist") - } else if os.IsNotExist(err) { - newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return nil, fmt.Errorf("newsid %s exist", newSidFile) } - _, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777) - io.Copy(newf, f) - } else if os.IsNotExist(err) { - newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { - return nil, err - } - f.Close() - os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))) - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) - var kv map[interface{}]interface{} - b, err := ioutil.ReadAll(newf) + err = os.MkdirAll(newPath, 0777) if err != nil { - return nil, err + SLogger.Println(err.Error()) } - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) + + // if old sid file exist + // 1.read and parse file content + // 2.write content to new sid file + // 3.remove old sid file, change new sid file atime and ctime + // 4.return FileSessionStore + _, err = os.Stat(oldSidFile) + if err == nil { + b, err := ioutil.ReadFile(oldSidFile) if err != nil { return nil, err } + + var kv map[interface{}]interface{} + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ioutil.WriteFile(newSidFile, b, 0777) + os.Remove(oldSidFile) + os.Chtimes(newSidFile, time.Now(), time.Now()) + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil } - ss := &FileSessionStore{sid: sid, values: kv} + + // if old sid file not exist, just create new sid file and return + newf, err := os.Create(newSidFile) + if err != nil { + return nil, err + } + newf.Close() + ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} return ss, nil } diff --git a/session/sess_test.go b/session/sess_test.go index b40865f3..60b47dd4 100644 --- a/session/sess_test.go +++ b/session/sess_test.go @@ -115,7 +115,7 @@ func TestParseConfig(t *testing.T) { if cf2.Gclifetime != 3600 { t.Fatal("parseconfig get gclifetime error") } - if cf2.EnableSetCookie != false { + if cf2.EnableSetCookie { t.Fatal("parseconfig get enableSetCookie error") } cconfig := new(cookieConfig) diff --git a/session/session.go b/session/session.go index 3c9d07ab..fb4b2821 100644 --- a/session/session.go +++ b/session/session.go @@ -86,6 +86,7 @@ type ManagerConfig struct { EnableSetCookie bool `json:"enableSetCookie,omitempty"` Gclifetime int64 `json:"gclifetime"` Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` Secure bool `json:"secure"` CookieLifeTime int `json:"cookieLifeTime"` ProviderConfig string `json:"providerConfig"` @@ -206,13 +207,13 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se session, err = manager.provider.SessionRead(sid) if err != nil { - return nil, errs + return nil, err } cookie := &http.Cookie{ Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, } @@ -251,7 +252,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { expiration := time.Now() cookie = &http.Cookie{Name: manager.config.CookieName, Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Expires: expiration, MaxAge: -1} @@ -285,7 +286,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, } diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go index 4dcf160a..c846e73a 100644 --- a/session/ssdb/sess_ssdb.go +++ b/session/ssdb/sess_ssdb.go @@ -26,10 +26,7 @@ func (p *SsdbProvider) connectInit() error { return errors.New("SessionInit First") } p.client, err = ssdb.Connect(p.host, p.port) - if err != nil { - return err - } - return nil + return err } func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error { @@ -41,11 +38,7 @@ func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error { if e != nil { return e } - err := p.connectInit() - if err != nil { - return err - } - return nil + return p.connectInit() } func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { @@ -126,10 +119,7 @@ func (p *SsdbProvider) SessionDestroy(sid string) error { } } _, err := p.client.Del(sid) - if err != nil { - return err - } - return nil + return err } func (p *SsdbProvider) SessionGC() { diff --git a/staticfile.go b/staticfile.go index b7be24f3..7c270947 100644 --- a/staticfile.go +++ b/staticfile.go @@ -109,14 +109,14 @@ var ( func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { mapKey := acceptEncoding + ":" + filePath mapLock.RLock() - mapFile, _ := staticFileMap[mapKey] + mapFile := staticFileMap[mapKey] mapLock.RUnlock() if isOk(mapFile, fi) { return mapFile.encoding != "", mapFile.encoding, mapFile, nil } mapLock.Lock() defer mapLock.Unlock() - if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) { + if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { file, err := os.Open(filePath) if err != nil { return false, "", nil, err diff --git a/swagger/swagger.go b/swagger/swagger.go index 409e264e..e0ac5cf5 100644 --- a/swagger/swagger.go +++ b/swagger/swagger.go @@ -97,6 +97,7 @@ type Parameter struct { Type string `json:"type,omitempty" yaml:"type,omitempty"` Format string `json:"format,omitempty" yaml:"format,omitempty"` Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` } // A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". @@ -126,7 +127,7 @@ type Propertie struct { Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` Title string `json:"title,omitempty" yaml:"title,omitempty"` Description string `json:"description,omitempty" yaml:"description,omitempty"` - Default string `json:"default,omitempty" yaml:"default,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"` Example string `json:"example,omitempty" yaml:"example,omitempty"` Required []string `json:"required,omitempty" yaml:"required,omitempty"` diff --git a/template.go b/template.go index 5415f5f0..104383c9 100644 --- a/template.go +++ b/template.go @@ -31,10 +31,11 @@ import ( ) var ( - beegoTplFuncMap = make(template.FuncMap) - // beeTemplates caching map and supported template file extensions. - beeTemplates = make(map[string]*template.Template) - templatesLock sync.RWMutex + beegoTplFuncMap = make(template.FuncMap) + beeViewPathTemplateLocked = false + // beeViewPathTemplates caching map and supported template file extensions per view + beeViewPathTemplates = make(map[string]map[string]*template.Template) + templatesLock sync.RWMutex // beeTemplateExt stores the template extension which will build beeTemplateExt = []string{"tpl", "html"} // beeTemplatePreprocessors stores associations of extension -> preprocessor handler @@ -45,23 +46,33 @@ var ( // writing the output to wr. // A template will be executed safely in parallel. func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { if BConfig.RunMode == DEV { templatesLock.RLock() defer templatesLock.RUnlock() } - if t, ok := beeTemplates[name]; ok { - var err error - if t.Lookup(name) != nil { - err = t.ExecuteTemplate(wr, name, data) - } else { - err = t.Execute(wr, data) + if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { + if t, ok := beeTemplates[name]; ok { + var err error + if t.Lookup(name) != nil { + err = t.ExecuteTemplate(wr, name, data) + } else { + err = t.Execute(wr, data) + } + if err != nil { + logs.Trace("template Execute err:", err) + } + return err } - if err != nil { - logs.Trace("template Execute err:", err) - } - return err + panic("can't find templatefile in the path:" + viewPath + "/" + name) } - panic("can't find templatefile in the path:" + name) + panic("Unknown view path:" + viewPath) } func init() { @@ -149,6 +160,24 @@ func AddTemplateExt(ext string) { beeTemplateExt = append(beeTemplateExt, ext) } +// AddViewPath adds a new path to the supported view paths. +//Can later be used by setting a controller ViewPath to this folder +//will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + if beeViewPathTemplateLocked { + if _, exist := beeViewPathTemplates[viewPath]; exist { + return nil //Ignore if viewpath already exists + } + panic("Can not add new view paths after beego.Run()") + } + beeViewPathTemplates[viewPath] = make(map[string]*template.Template) + return BuildTemplate(viewPath) +} + +func lockViewPaths() { + beeViewPathTemplateLocked = true +} + // BuildTemplate will build all template files in a directory. // it makes beego can render any template file in view directory. func BuildTemplate(dir string, files ...string) error { @@ -158,6 +187,10 @@ func BuildTemplate(dir string, files ...string) error { } return errors.New("dir open err") } + beeTemplates, ok := beeViewPathTemplates[dir] + if !ok { + panic("Unknown view path: " + dir) + } self := &templateFile{ root: dir, files: make(map[string][]string), @@ -184,7 +217,7 @@ func BuildTemplate(dir string, files ...string) error { t, err = getTemplate(self.root, file, v...) } if err != nil { - logs.Trace("parse template err:", file, err) + logs.Error("parse template err:", file, err) } else { beeTemplates[file] = t } @@ -197,9 +230,12 @@ func BuildTemplate(dir string, files ...string) error { func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) { var fileAbsPath string + var rParent string if filepath.HasPrefix(file, "../") { + rParent = filepath.Join(filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) } else { + rParent = file fileAbsPath = filepath.Join(root, file) } if e := utils.FileExists(fileAbsPath); !e { @@ -224,7 +260,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp if !HasTemplateExt(m[1]) { continue } - t, _, err = getTplDeep(root, m[1], file, t) + _, _, err = getTplDeep(root, m[1], rParent, t) if err != nil { return nil, [][]string{}, err } @@ -263,7 +299,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) - } else if subMods1 != nil && len(subMods1) > 0 { + } else if len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } break @@ -284,7 +320,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others t, subMods1, err = getTplDeep(root, otherFile, "", t) if err != nil { logs.Trace("template parse file err:", err) - } else if subMods1 != nil && len(subMods1) > 0 { + } else if len(subMods1) > 0 { t, err = _getTemplate(t, root, subMods1, others...) } break diff --git a/template_test.go b/template_test.go index 4f13736c..17690965 100644 --- a/template_test.go +++ b/template_test.go @@ -67,9 +67,10 @@ func TestTemplate(t *testing.T) { f.Close() } } - if err := BuildTemplate(dir); err != nil { + if err := AddViewPath(dir); err != nil { t.Fatal(err) } + beeTemplates := beeViewPathTemplates[dir] if len(beeTemplates) != 3 { t.Fatalf("should be 3 but got %v", len(beeTemplates)) } @@ -103,6 +104,12 @@ var user = ` func TestRelativeTemplate(t *testing.T) { dir := "_beeTmp" + + //Just add dir to known viewPaths + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + files := []string{ "easyui/public/menu.tpl", "easyui/rbac/user.tpl", @@ -126,6 +133,7 @@ func TestRelativeTemplate(t *testing.T) { if err := BuildTemplate(dir, files[1]); err != nil { t.Fatal(err) } + beeTemplates := beeViewPathTemplates[dir] if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { t.Fatal(err) } diff --git a/templatefunc.go b/templatefunc.go index 01751717..4fdc936d 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -26,6 +26,12 @@ import ( "time" ) +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" +) + // Substr returns the substr from start to length. func Substr(s string, start, length int) string { bt := []rune(s) @@ -46,26 +52,25 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - src := string(html) re, _ := regexp.Compile("\\<[\\S\\s]+?\\>") - src = re.ReplaceAllStringFunc(src, strings.ToLower) + html = re.ReplaceAllStringFunc(html, strings.ToLower) //remove STYLE re, _ = regexp.Compile("\\") - src = re.ReplaceAllString(src, "") + html = re.ReplaceAllString(html, "") //remove SCRIPT re, _ = regexp.Compile("\\") - src = re.ReplaceAllString(src, "") + html = re.ReplaceAllString(html, "") re, _ = regexp.Compile("\\<[\\S\\s]+?\\>") - src = re.ReplaceAllString(src, "\n") + html = re.ReplaceAllString(html, "\n") re, _ = regexp.Compile("\\s{2,}") - src = re.ReplaceAllString(src, "\n") + html = re.ReplaceAllString(html, "\n") - return strings.TrimSpace(src) + return strings.TrimSpace(html) } // DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" @@ -193,7 +198,7 @@ func Str2html(raw string) template.HTML { } // Htmlquote returns quoted html string. -func Htmlquote(src string) string { +func Htmlquote(text string) string { //HTML编码为实体符号 /* Encodes `text` for raw use in HTML. @@ -201,8 +206,6 @@ func Htmlquote(src string) string { '<'&">' */ - text := string(src) - text = strings.Replace(text, "&", "&", -1) // Must be done first! text = strings.Replace(text, "<", "<", -1) text = strings.Replace(text, ">", ">", -1) @@ -216,7 +219,7 @@ func Htmlquote(src string) string { } // Htmlunquote returns unquoted html string. -func Htmlunquote(src string) string { +func Htmlunquote(text string) string { //实体符号解释为HTML /* Decodes `text` that's HTML quoted. @@ -227,7 +230,6 @@ func Htmlunquote(src string) string { // strings.Replace(s, old, new, n) // 在s字符串中,把old字符串替换为new字符串,n表示替换的次数,小于0表示全部替换 - text := string(src) text = strings.Replace(text, " ", " ", -1) text = strings.Replace(text, "”", "”", -1) text = strings.Replace(text, "“", "“", -1) @@ -262,19 +264,17 @@ func URLFor(endpoint string, values ...interface{}) string { } // AssetsJs returns script tag with src string. -func AssetsJs(src string) template.HTML { - text := string(src) +func AssetsJs(text string) template.HTML { - text = "" + text = "" return template.HTML(text) } // AssetsCSS returns stylesheet link tag with src string. -func AssetsCSS(src string) template.HTML { - text := string(src) +func AssetsCSS(text string) template.HTML { - text = "" + text = "" return template.HTML(text) } @@ -352,11 +352,28 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e case reflect.Struct: switch fieldT.Type.String() { case "time.Time": - format := time.RFC3339 - if len(tags) > 1 { - format = tags[1] + var ( + t time.Time + err error + ) + if len(value) >= 25 { + value = value[:25] + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if len(value) >= 19 { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + t, err = time.ParseInLocation(formatDate, value, time.Local) + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + t, err = time.ParseInLocation(formatTime, value, time.Local) } - t, err := time.ParseInLocation(format, value, time.Local) + if err != nil { return err } diff --git a/templatefunc_test.go b/templatefunc_test.go index a1ec1136..20d9b850 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -173,7 +173,7 @@ func TestParseForm(t *testing.T) { if u.Intro != "I am an engineer!" { t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) } - if u.StrBool != true { + if !u.StrBool { t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) } y, m, d := u.Date.Date() @@ -255,43 +255,43 @@ func TestParseFormTag(t *testing.T) { objT := reflect.TypeOf(&user{}).Elem() label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0)) - if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) { + if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { t.Errorf("Form Tag with name, label and type was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1)) - if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) { + if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { t.Errorf("Form Tag with label and type but without name was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2)) - if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) { + if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { t.Errorf("Form Tag containing only label was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3)) - if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false && + if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && id == "name" && class == "form-name") { t.Errorf("Form Tag containing only name was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4)) - if ignored == false { + if !ignored { t.Errorf("Form Tag that should be ignored was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5)) - if !(name == "name" && required == true) { + if !(name == "name" && required) { t.Errorf("Form Tag containing only name and required was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6)) - if !(name == "name" && required == false) { + if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") } label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7)) - if !(name == "name" && required == false) { + if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and not required was not correctly parsed.") } diff --git a/toolbox/statistics.go b/toolbox/statistics.go index 69b88772..d014544c 100644 --- a/toolbox/statistics.go +++ b/toolbox/statistics.go @@ -117,6 +117,8 @@ func (m *URLMap) GetMap() map[string]interface{} { // GetMapData return all mapdata func (m *URLMap) GetMapData() []map[string]interface{} { + m.lock.Lock() + defer m.lock.Unlock() var resultLists []map[string]interface{} diff --git a/toolbox/task.go b/toolbox/task.go index abd411c8..672717cd 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -427,6 +427,7 @@ func run() { } continue case <-changed: + now = time.Now().Local() continue case <-stop: return diff --git a/tree.go b/tree.go index 25b78e50..2d6c3fc3 100644 --- a/tree.go +++ b/tree.go @@ -288,10 +288,10 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{ return nil } w := make([]string, 0, 20) - return t.match(pattern, w, ctx) + return t.match(pattern[1:], pattern, w, ctx) } -func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { +func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { if len(pattern) > 0 { i := 0 for ; i < len(pattern) && pattern[i] == '/'; i++ { @@ -301,13 +301,13 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte // Handle leaf nodes: if len(pattern) == 0 { for _, l := range t.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } if t.wildcard != nil { for _, l := range t.wildcard.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } @@ -327,7 +327,12 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte } for _, subTree := range t.fixrouters { if subTree.prefix == seg { - runObject = subTree.match(pattern, wildcardValues, ctx) + if len(pattern) != 0 && pattern[0] == '/' { + treePattern = pattern[1:] + } else { + treePattern = pattern + } + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { break } @@ -339,7 +344,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte if strings.HasSuffix(seg, str) { for _, subTree := range t.fixrouters { if subTree.prefix == seg[:len(seg)-len(str)] { - runObject = subTree.match(pattern, wildcardValues, ctx) + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { ctx.Input.SetParam(":ext", str[1:]) } @@ -349,7 +354,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte } } if runObject == nil && t.wildcard != nil { - runObject = t.wildcard.match(pattern, append(wildcardValues, seg), ctx) + runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) } if runObject == nil && len(t.leaves) > 0 { @@ -368,7 +373,7 @@ func (t *Tree) match(pattern string, wildcardValues []string, ctx *context.Conte wildcardValues = append(wildcardValues, pattern[start:i]) } for _, l := range t.leaves { - if ok := l.match(wildcardValues, ctx); ok { + if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject } } @@ -386,7 +391,7 @@ type leafInfo struct { runObject interface{} } -func (leaf *leafInfo) match(wildcardValues []string, ctx *context.Context) (ok bool) { +func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path @@ -394,7 +399,7 @@ func (leaf *leafInfo) match(wildcardValues []string, ctx *context.Context) (ok b } // match * if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { - ctx.Input.SetParam(":splat", path.Join(wildcardValues...)) + ctx.Input.SetParam(":splat", treePattern) return true } // match *.* or :id diff --git a/tree_test.go b/tree_test.go index 81ff7edd..d412a348 100644 --- a/tree_test.go +++ b/tree_test.go @@ -42,7 +42,7 @@ func init() { routers = append(routers, testinfo{"/", "/", nil}) routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) - routers = append(routers, testinfo{"/*", "/customer/123", map[string]string{":splat": "customer/123"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) diff --git a/utils/captcha/image.go b/utils/captcha/image.go index 0ceb8e42..c3c9a83a 100644 --- a/utils/captcha/image.go +++ b/utils/captcha/image.go @@ -474,7 +474,7 @@ func randomBrightness(c color.RGBA, max uint8) color.RGBA { uint8(int(c.R) + n), uint8(int(c.G) + n), uint8(int(c.B) + n), - uint8(c.A), + c.A, } } diff --git a/utils/mail.go b/utils/mail.go index 10555a0a..e3fa1c90 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -232,14 +232,16 @@ func (e *Email) Send() error { return errors.New("Must specify at least one To address") } - from, err := mail.ParseAddress(e.Username) + // Use the username if no From is provided + if len(e.From) == 0 { + e.From = e.Username + } + + from, err := mail.ParseAddress(e.From) if err != nil { return err } - if len(e.From) == 0 { - e.From = e.Username - } // use mail's RFC 2047 to encode any string e.Subject = qEncode("utf-8", e.Subject) diff --git a/utils/safemap.go b/utils/safemap.go index 2e438f2c..1793030a 100644 --- a/utils/safemap.go +++ b/utils/safemap.go @@ -61,10 +61,8 @@ func (m *BeeMap) Set(k interface{}, v interface{}) bool { func (m *BeeMap) Check(k interface{}) bool { m.lock.RLock() defer m.lock.RUnlock() - if _, ok := m.bm[k]; !ok { - return false - } - return true + _, ok := m.bm[k] + return ok } // Delete the given key and value. @@ -84,3 +82,10 @@ func (m *BeeMap) Items() map[interface{}]interface{} { } return r } + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.bm) +} diff --git a/utils/safemap_test.go b/utils/safemap_test.go index fb271d18..1bfe8699 100644 --- a/utils/safemap_test.go +++ b/utils/safemap_test.go @@ -14,25 +14,44 @@ package utils -import ( - "testing" -) +import "testing" -func Test_beemap(t *testing.T) { - bm := NewBeeMap() - if !bm.Set("astaxie", 1) { - t.Error("set Error") - } - if !bm.Check("astaxie") { - t.Error("check err") - } +var safeMap *BeeMap - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - bm.Delete("astaxie") - if bm.Check("astaxie") { - t.Error("delete err") +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) } } diff --git a/validation/validation.go b/validation/validation.go index 489dfa5e..9dc51106 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -349,7 +349,7 @@ func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { //Step 1: validate obj itself firstly // fails if objc is not struct pass, err := v.Valid(objc) - if err != nil || false == pass { + if err != nil || !pass { return pass, err // Stop recursive validation } // Step 2: Validate struct's struct fields diff --git a/validation/validation_test.go b/validation/validation_test.go index ec65b6d0..83e881bf 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -214,6 +214,12 @@ func TestEmail(t *testing.T) { if !valid.Email("suchuangji@gmail.com", "email").Ok { t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } } func TestIP(t *testing.T) { diff --git a/validation/validators.go b/validation/validators.go index 9b04c5ce..01aed443 100644 --- a/validation/validators.go +++ b/validation/validators.go @@ -518,7 +518,7 @@ func (a AlphaDash) GetLimitValue() interface{} { return nil } -var emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?") +var emailPattern = regexp.MustCompile("^[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?$") // Email check struct type Email struct {