diff --git a/README.md b/README.md index fbd7ccb7..d3c92d84 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) +[![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) beego is used for rapid development of RESTful APIs, web apps and backend services in Go. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. diff --git a/admin_test.go b/admin_test.go index 0bf985f2..2348792e 100644 --- a/admin_test.go +++ b/admin_test.go @@ -65,6 +65,7 @@ func oldMap() map[string]interface{} { m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum m["BConfig.Log.Outputs"] = BConfig.Log.Outputs diff --git a/beego.go b/beego.go index 94ec92b2..1bc8bb85 100644 --- a/beego.go +++ b/beego.go @@ -23,7 +23,7 @@ import ( const ( // VERSION represent beego web framework version. - VERSION = "1.7.1" + VERSION = "1.7.2" // DEV is for develop DEV = "dev" diff --git a/config.go b/config.go index a4f40611..36bf445c 100644 --- a/config.go +++ b/config.go @@ -86,17 +86,18 @@ type WebConfig struct { // SessionConfig holds session related config type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - EnableSidInHttpHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHttpHeader string - EnableSidInUrlQuery bool // enable get the sessionId from Url Query params + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params } // LogConfig holds Log related config @@ -143,6 +144,9 @@ func init() { if err = parseConfig(appConfigPath); err != nil { panic(err) } + if err = os.Chdir(AppPath); err != nil { + panic(err) + } } func recoverPanic(ctx *context.Context) { @@ -221,17 +225,18 @@ func newBConfig() *Config { XSRFKey: "beegoxsrf", XSRFExpire: 0, Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", - EnableSidInHttpHeader: false, // enable store/get the sessionId into/from http headers - SessionNameInHttpHeader: "Beegosessionid", - EnableSidInUrlQuery: false, // enable get the sessionId from Url Query params + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params }, }, Log: LogConfig{ @@ -294,6 +299,10 @@ func assignConfig(ac config.Configer) error { } if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) los := strings.Split(lo, ";") for _, v := range los { if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { diff --git a/config/config.go b/config/config.go index 9f41fb79..e8201a24 100644 --- a/config/config.go +++ b/config/config.go @@ -43,6 +43,8 @@ package config import ( "fmt" "os" + "reflect" + "time" ) // Configer defines how to get and set value from configuration raw data. @@ -204,3 +206,37 @@ func ParseBool(val interface{}) (value bool, err error) { } return false, fmt.Errorf("parsing : invalid syntax") } + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/config/ini.go b/config/ini.go index 9371fe61..b3332bd8 100644 --- a/config/ini.go +++ b/config/ini.go @@ -23,6 +23,7 @@ import ( "io/ioutil" "os" "path" + "path/filepath" "strconv" "strings" "sync" @@ -132,8 +133,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { includefiles := strings.Fields(key) if includefiles[0] == "include" && len(includefiles) == 2 { otherfile := strings.Trim(includefiles[1], "\"") - if !path.IsAbs(otherfile) { - otherfile = path.Join(path.Dir(name), otherfile) + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(filepath.Dir(name), otherfile) } i, err := ini.parseFile(otherfile) if err != nil { diff --git a/config/xml/xml.go b/config/xml/xml.go index 0c4e4d27..66115714 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -193,10 +193,14 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []stri // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil } - return nil, errors.New("not exist setction") + return nil, fmt.Errorf("section '%s' not found", section) } // SaveConfigFile save the config into file diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go index d8a09a59..346c866e 100644 --- a/config/xml/xml_test.go +++ b/config/xml/xml_test.go @@ -37,6 +37,10 @@ func TestXML(t *testing.T) { true ${GOPATH} ${GOPATH||/home/go} + +1 +MySection + ` keyValue = map[string]interface{}{ @@ -65,11 +69,22 @@ func TestXML(t *testing.T) { } f.Close() defer os.Remove("testxml.conf") + xmlconf, err := config.NewConfig("xml", "testxml.conf") if err != nil { t.Fatal(err) } + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + for k, v := range keyValue { var ( diff --git a/controller.go b/controller.go index c7eb118d..e484ce49 100644 --- a/controller.go +++ b/controller.go @@ -399,6 +399,16 @@ func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { return int8(i64), err } +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + // GetInt16 returns input as an int16 or the default value while it's present and input is blank func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { strv := c.Ctx.Input.Query(key) @@ -409,6 +419,16 @@ func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { return int16(i64), err } +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + // GetInt32 returns input as an int32 or the default value while it's present and input is blank func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { strv := c.Ctx.Input.Query(key) @@ -419,6 +439,16 @@ func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { return int32(i64), err } +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + // GetInt64 returns input value as int64 or the default value while it's present and input is blank. func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { strv := c.Ctx.Input.Query(key) @@ -428,6 +458,15 @@ func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { return strconv.ParseInt(strv, 10, 64) } +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + // GetBool returns input value as bool or the default value while it's present and input is blank. func (c *Controller) GetBool(key string, def ...bool) (bool, error) { strv := c.Ctx.Input.Query(key) @@ -453,7 +492,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, } // GetFiles return multi-upload files -// files, err:=c.Getfiles("myfiles") +// files, err:=c.GetFiles("myfiles") // if err != nil { // http.Error(w, err.Error(), http.StatusNoContent) // return diff --git a/controller_test.go b/controller_test.go index 51d3a5b7..132971a1 100644 --- a/controller_test.go +++ b/controller_test.go @@ -15,6 +15,8 @@ package beego import ( + "math" + "strconv" "testing" "github.com/astaxie/beego/context" @@ -75,3 +77,47 @@ func TestGetInt64(t *testing.T) { t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) } } + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} diff --git a/hooks.go b/hooks.go index 0c7d05fe..b5a5e6c5 100644 --- a/hooks.go +++ b/hooks.go @@ -53,10 +53,11 @@ func registerSession() error { conf.Secure = BConfig.Listen.EnableHTTPS conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly conf.Domain = BConfig.WebConfig.Session.SessionDomain - conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.EnableSidInHttpHeader - conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHttpHeader - conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.EnableSidInUrlQuery + conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { return err diff --git a/httplib/httplib.go b/httplib/httplib.go index 7e6f2700..510ad75e 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -136,6 +136,7 @@ type BeegoHTTPSettings struct { TLSClientConfig *tls.Config Proxy func(*http.Request) (*url.URL, error) Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error EnableCookie bool Gzip bool DumpBody bool @@ -265,6 +266,15 @@ func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) return b } +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { @@ -446,6 +456,10 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { b.req.Header.Set("User-Agent", b.setting.UserAgent) } + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + if b.setting.ShowDebug { dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) if err != nil { diff --git a/logs/jianliao.go b/logs/jianliao.go index 3755118d..16773c93 100644 --- a/logs/jianliao.go +++ b/logs/jianliao.go @@ -56,10 +56,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) } - resp.Body.Close() return nil } diff --git a/logs/log.go b/logs/log.go index 3d512d2e..806ebaa0 100644 --- a/logs/log.go +++ b/logs/log.go @@ -380,7 +380,10 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { // Warning Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { - bl.Warn(format, v...) + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) } // Notice Log NOTICE level message. @@ -393,7 +396,10 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { // Informational Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { - bl.Info(format, v...) + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) } // Debug Log DEBUG level message. @@ -425,7 +431,10 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { // Trace Log TRACE level message. // compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { - bl.Debug(format, v...) + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) } // Flush flush all chan data. diff --git a/logs/slack.go b/logs/slack.go index eddedd5d..90f009cb 100644 --- a/logs/slack.go +++ b/logs/slack.go @@ -44,10 +44,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { if err != nil { return err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) } - resp.Body.Close() return nil } diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 10fe2657..1016de2b 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -16,6 +16,8 @@ package orm import ( "fmt" + "reflect" + "strings" ) // mysql operators. @@ -96,6 +98,83 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +// Add "`" for mysql sql building +func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { + + iouStr := "" + argsMap := map[string]string{} + + iouStr = "ON DUPLICATE KEY UPDATE" + + //Get on the key-value pairs + for _, v := range args { + kv := strings.Split(v, "=") + if len(kv) == 2 { + argsMap[strings.ToLower(kv[0])] = kv[1] + } + } + + isMulti := false + names := make([]string, 0, len(mi.fields.dbcols)-1) + Q := d.ins.TableQuote() + values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + + if err != nil { + return 0, err + } + + marks := make([]string, len(names)) + updateValues := make([]interface{}, 0) + updates := make([]string, len(names)) + + for i, v := range names { + marks[i] = "?" + valueStr := argsMap[strings.ToLower(v)] + if valueStr != "" { + updates[i] = "`" + v + "`" + "=" + valueStr + } else { + updates[i] = "`" + v + "`" + "=?" + updateValues = append(updateValues, values[i]) + } + } + + values = append(values, updateValues...) + + sep := fmt.Sprintf("%s, %s", Q, Q) + qmarks := strings.Join(marks, ", ") + qupdates := strings.Join(updates, ", ") + columns := strings.Join(names, sep) + + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + //conflitValue maybe is a int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + + d.ins.ReplaceMarks(&query) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.Exec(query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + return res.LastInsertId() + } + return 0, err + } + + row := q.QueryRow(query, values...) + var id int64 + err = row.Scan(&id) + return id, err +} + // create new mysql dbBaser. func newdbBaseMysql() dbBaser { b := new(dbBaseMysql) diff --git a/orm/db_utils.go b/orm/db_utils.go index 0279a14a..923917ec 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -147,8 +147,10 @@ outFor: arg = v.In(tz).Format(formatDate) } else if fi != nil && fi.fieldType == TypeDateTimeField { arg = v.In(tz).Format(formatDateTime) - } else { + } else if fi != nil && fi.fieldType == TypeTimeField { arg = v.In(tz).Format(formatTime) + } else { + arg = v.In(tz).Format(formatDateTime) } } else { typ := val.Type() diff --git a/orm/orm_conds.go b/orm/orm_conds.go index e56d6fbb..f6e389ec 100644 --- a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -75,6 +75,19 @@ func (c *Condition) AndCond(cond *Condition) *Condition { return c } +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) + } + return c +} + // Or add OR expression to condition func (c Condition) Or(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { @@ -105,6 +118,19 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + c = c.clone() + if c == cond { + panic(fmt.Errorf(" cannot use self as sub cond")) + } + + if cond != nil { + c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) + } + return c +} + // IsEmpty check the condition arguments are empty or not. func (c *Condition) IsEmpty() bool { return len(c.params) == 0 diff --git a/orm/orm_test.go b/orm/orm_test.go index fbf4768d..adfe0066 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -909,6 +909,16 @@ func TestSetCond(t *testing.T) { num, err = qs.SetCond(cond2).Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) } func TestLimit(t *testing.T) { diff --git a/orm/utils.go b/orm/utils.go index 6e23447e..bf43ceb0 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -16,6 +16,7 @@ package orm import ( "fmt" + "math/big" "reflect" "strconv" "strings" @@ -87,6 +88,14 @@ func (f StrTo) Int32() (int32, error) { // Int64 string to int64 func (f StrTo) Int64() (int64, error) { v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return int64(v), err + } + return ni.Int64(), nil + } return int64(v), err } @@ -117,6 +126,14 @@ func (f StrTo) Uint32() (uint32, error) { // Uint64 string to uint64 func (f StrTo) Uint64() (uint64, error) { v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return uint64(v), err + } + return ni.Uint64(), nil + } return uint64(v), err } diff --git a/policy.go b/policy.go new file mode 100644 index 00000000..2b91fdcc --- /dev/null +++ b/policy.go @@ -0,0 +1,97 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + var urlPath = cont.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := cont.Input.Method() + isWildcard := false + // Find policy for current method + t, ok := p.policies[httpMethod] + // If not found - find policy for whole controller + if !ok { + t, ok = p.policies["*"] + isWildcard = true + } + if ok { + runObjects := t.Match(urlPath, cont) + if r, ok := runObjects.([]PolicyFunc); ok { + return r + } else if !isWildcard { + // If no policies found and we checked not for "*" method - try to find it + t, ok = p.policies["*"] + if ok { + runObjects = t.Match(urlPath, cont) + if r, ok = runObjects.([]PolicyFunc); ok { + return r + } + } + } + } + return nil +} + +func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { + method = strings.ToUpper(method) + p.enablePolicy = true + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.policies[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.policies[method] = t + } +} + +// Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + BeeApp.Handlers.addToPolicy(method, pattern, policy...) +} + +// Find policies and execute if were found +func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { + if !p.enablePolicy { + return false + } + // Find Policy for method + policyList := p.FindPolicy(cont) + if len(policyList) > 0 { + // Run policies + for _, runPolicy := range policyList { + runPolicy(cont) + if cont.ResponseWriter.Started { + return true + } + } + return false + } + return false +} diff --git a/router.go b/router.go index 456c3221..74cf02a1 100644 --- a/router.go +++ b/router.go @@ -114,6 +114,8 @@ type controllerInfo struct { // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree + enablePolicy bool + policies map[string]*Tree enableFilter bool filters [FinishRouter + 1][]*FilterRouter pool sync.Pool @@ -122,7 +124,8 @@ type ControllerRegister struct { // NewControllerRegister returns a new ControllerRegister. func NewControllerRegister() *ControllerRegister { cr := &ControllerRegister{ - routers: make(map[string]*Tree), + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), } cr.pool.New = func() interface{} { return beecontext.NewContext() @@ -711,6 +714,11 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) goto Admin } + //check policies + if p.execPolicy(context, urlPath) { + goto Admin + } + if routerInfo != nil { //store router pattern into context context.Input.SetData("RouterPattern", routerInfo.pattern) diff --git a/session/session.go b/session/session.go index 3c9d07ab..fb4b2821 100644 --- a/session/session.go +++ b/session/session.go @@ -86,6 +86,7 @@ type ManagerConfig struct { EnableSetCookie bool `json:"enableSetCookie,omitempty"` Gclifetime int64 `json:"gclifetime"` Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` Secure bool `json:"secure"` CookieLifeTime int `json:"cookieLifeTime"` ProviderConfig string `json:"providerConfig"` @@ -206,13 +207,13 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se session, err = manager.provider.SessionRead(sid) if err != nil { - return nil, errs + return nil, err } cookie := &http.Cookie{ Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, } @@ -251,7 +252,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { expiration := time.Now() cookie = &http.Cookie{Name: manager.config.CookieName, Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Expires: expiration, MaxAge: -1} @@ -285,7 +286,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", - HttpOnly: true, + HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, } diff --git a/swagger/swagger.go b/swagger/swagger.go index 409e264e..e0ac5cf5 100644 --- a/swagger/swagger.go +++ b/swagger/swagger.go @@ -97,6 +97,7 @@ type Parameter struct { Type string `json:"type,omitempty" yaml:"type,omitempty"` Format string `json:"format,omitempty" yaml:"format,omitempty"` Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` } // A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". @@ -126,7 +127,7 @@ type Propertie struct { Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` Title string `json:"title,omitempty" yaml:"title,omitempty"` Description string `json:"description,omitempty" yaml:"description,omitempty"` - Default string `json:"default,omitempty" yaml:"default,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"` Example string `json:"example,omitempty" yaml:"example,omitempty"` Required []string `json:"required,omitempty" yaml:"required,omitempty"` diff --git a/utils/mail.go b/utils/mail.go index 10555a0a..e3fa1c90 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -232,14 +232,16 @@ func (e *Email) Send() error { return errors.New("Must specify at least one To address") } - from, err := mail.ParseAddress(e.Username) + // Use the username if no From is provided + if len(e.From) == 0 { + e.From = e.Username + } + + from, err := mail.ParseAddress(e.From) if err != nil { return err } - if len(e.From) == 0 { - e.From = e.Username - } // use mail's RFC 2047 to encode any string e.Subject = qEncode("utf-8", e.Subject)