diff --git a/beego.go b/beego.go index 29cd45a0..e658c3b5 100644 --- a/beego.go +++ b/beego.go @@ -12,7 +12,7 @@ import ( ) // beego web framework version. -const VERSION = "1.1.0" +const VERSION = "1.1.1" type hookfunc func() error //hook function to run var hooks []hookfunc //hook function slice to store the hookfunc @@ -28,12 +28,12 @@ type GroupRouters []groupRouter // Get a new GroupRouters func NewGroupRouters() GroupRouters { - return make([]groupRouter, 0) + return make(GroupRouters, 0) } // Add Router in the GroupRouters // it is for plugin or module to register router -func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) { +func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) { var newRG groupRouter if len(mappingMethod) > 0 { newRG = groupRouter{ @@ -48,16 +48,16 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM "", } } - gr = append(gr, newRG) + *gr = append(*gr, newRG) } -func (gr GroupRouters) AddAuto(c ControllerInterface) { +func (gr *GroupRouters) AddAuto(c ControllerInterface) { newRG := groupRouter{ "", c, "", } - gr = append(gr, newRG) + *gr = append(*gr, newRG) } // AddGroupRouter with the prefix diff --git a/cache/file.go b/cache/file.go index 410da3a0..d750fecc 100644 --- a/cache/file.go +++ b/cache/file.go @@ -147,6 +147,8 @@ func (this *FileCache) Get(key string) interface{} { // timeout means how long to keep this file, unit of ms. // if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. func (this *FileCache) Put(key string, val interface{}, timeout int64) error { + gob.Register(val) + filename := this.getCacheFileName(key) var item FileCacheItem item.Data = val diff --git a/context/context.go b/context/context.go index ee308476..36d566b5 100644 --- a/context/context.go +++ b/context/context.go @@ -1,7 +1,14 @@ package context import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" "net/http" + "strconv" + "strings" + "time" "github.com/astaxie/beego/middleware" ) @@ -59,3 +66,41 @@ func (ctx *Context) GetCookie(key string) string { func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { ctx.Output.Cookie(name, value, others...) } + +// Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha1.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha1.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} diff --git a/context/output.go b/context/output.go index 43948e15..a8a304b6 100644 --- a/context/output.go +++ b/context/output.go @@ -77,39 +77,77 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface var b bytes.Buffer fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) if len(others) > 0 { - switch others[0].(type) { + switch v := others[0].(type) { case int: - if others[0].(int) > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int)) - } else if others[0].(int) < 0 { + if v > 0 { + fmt.Fprintf(&b, "; Max-Age=%d", v) + } else if v < 0 { fmt.Fprintf(&b, "; Max-Age=0") } case int64: - if others[0].(int64) > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64)) - } else if others[0].(int64) < 0 { + if v > 0 { + fmt.Fprintf(&b, "; Max-Age=%d", v) + } else if v < 0 { fmt.Fprintf(&b, "; Max-Age=0") } case int32: - if others[0].(int32) > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32)) - } else if others[0].(int32) < 0 { + if v > 0 { + fmt.Fprintf(&b, "; Max-Age=%d", v) + } else if v < 0 { fmt.Fprintf(&b, "; Max-Age=0") } } } + + // the settings below + // Path, Domain, Secure, HttpOnly + // can use nil skip set + + // default "/" if len(others) > 1 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string))) + if v, ok := others[1].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + } + } else { + fmt.Fprintf(&b, "; Path=%s", "/") } + + // default empty if len(others) > 2 { - fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string))) + if v, ok := others[2].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) + } } + + // default empty if len(others) > 3 { - fmt.Fprintf(&b, "; Secure") + var secure bool + switch v := others[3].(type) { + case bool: + secure = v + default: + if others[3] != nil { + secure = true + } + } + if secure { + fmt.Fprintf(&b, "; Secure") + } } + + // default true + httponly := true if len(others) > 4 { + if v, ok := others[4].(bool); ok && !v || others[4] == nil { + // HttpOnly = false + httponly = false + } + } + + if httponly { fmt.Fprintf(&b, "; HttpOnly") } + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) } diff --git a/controller.go b/controller.go index 9d783747..c9ad10e9 100644 --- a/controller.go +++ b/controller.go @@ -2,11 +2,7 @@ package beego import ( "bytes" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" "errors" - "fmt" "html/template" "io" "io/ioutil" @@ -17,7 +13,6 @@ import ( "reflect" "strconv" "strings" - "time" "github.com/astaxie/beego/context" "github.com/astaxie/beego/session" @@ -313,11 +308,11 @@ func (c *Controller) GetString(key string) string { // GetStrings returns the input string slice by key string. // it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. func (c *Controller) GetStrings(key string) []string { - r := c.Ctx.Request - if r.Form == nil { + f := c.Input() + if f == nil { return []string{} } - vs := r.Form[key] + vs := f[key] if len(vs) > 0 { return vs } @@ -417,40 +412,12 @@ func (c *Controller) IsAjax() bool { // GetSecureCookie returns decoded cookie value from encoded browser cookie values. func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { - val := c.Ctx.GetCookie(key) - if val == "" { - return "", false - } - - parts := strings.SplitN(val, "|", 3) - - if len(parts) != 3 { - return "", false - } - - vs := parts[0] - timestamp := parts[1] - sig := parts[2] - - h := hmac.New(sha1.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - - if fmt.Sprintf("%02x", h.Sum(nil)) != sig { - return "", false - } - res, _ := base64.URLEncoding.DecodeString(vs) - return string(res), true + return c.Ctx.GetSecureCookie(Secret, key) } // SetSecureCookie puts value into cookie after encoded the value. -func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) { - vs := base64.URLEncoding.EncodeToString([]byte(val)) - timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) - h := hmac.New(sha1.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - sig := fmt.Sprintf("%02x", h.Sum(nil)) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - c.Ctx.SetCookie(name, cookie, age, "/") +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + c.Ctx.SetSecureCookie(Secret, name, value, others...) } // XsrfToken creates a xsrf token string and returns. diff --git a/httplib/httplib.go b/httplib/httplib.go index 1462ee45..d313c603 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -24,7 +24,7 @@ func Get(url string) *BeegoHttpRequest { req.Method = "GET" req.Header = http.Header{} req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} } // Post returns *BeegoHttpRequest with POST method. @@ -33,7 +33,7 @@ func Post(url string) *BeegoHttpRequest { req.Method = "POST" req.Header = http.Header{} req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} } // Put returns *BeegoHttpRequest with PUT method. @@ -42,7 +42,7 @@ func Put(url string) *BeegoHttpRequest { req.Method = "PUT" req.Header = http.Header{} req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} } // Delete returns *BeegoHttpRequest DELETE GET method. @@ -51,7 +51,7 @@ func Delete(url string) *BeegoHttpRequest { req.Method = "DELETE" req.Header = http.Header{} req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} } // Head returns *BeegoHttpRequest with HEAD method. @@ -60,7 +60,7 @@ func Head(url string) *BeegoHttpRequest { req.Method = "HEAD" req.Header = http.Header{} req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} } // BeegoHttpRequest provides more useful methods for requesting one url than http.Request. @@ -72,6 +72,8 @@ type BeegoHttpRequest struct { connectTimeout time.Duration readWriteTimeout time.Duration tlsClientConfig *tls.Config + proxy func(*http.Request) (*url.URL, error) + transport http.RoundTripper } // Debug sets show debug or not when executing request. @@ -105,6 +107,24 @@ func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest { return b } +// Set transport to +func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest { + b.transport = transport + return b +} + +// Set http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest { + b.proxy = proxy + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest { @@ -171,12 +191,34 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { println(string(dump)) } - client := &http.Client{ - Transport: &http.Transport{ + trans := b.transport + + if trans == nil { + // create default transport + trans = &http.Transport{ TLSClientConfig: b.tlsClientConfig, + Proxy: b.proxy, Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout), - }, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.tlsClientConfig + } + if t.Proxy == nil { + t.Proxy = b.proxy + } + if t.Dial == nil { + t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout) + } + } } + + client := &http.Client{ + Transport: trans, + } + resp, err := client.Do(b.req) if err != nil { return nil, err diff --git a/orm/db.go b/orm/db.go index 60e53765..dfb53621 100644 --- a/orm/db.go +++ b/orm/db.go @@ -35,7 +35,7 @@ var ( "istartswith": true, "iendswith": true, "in": true, - // "range": true, + "between": true, // "year": true, // "month": true, // "day": true, @@ -103,15 +103,36 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { switch fi.fieldType { case TypeBooleanField: - value = field.Bool() - case TypeCharField, TypeTextField: - value = field.String() - case TypeFloatField, TypeDecimalField: - vu := field.Interface() - if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() + if nb, ok := field.Interface().(sql.NullBool); ok { + value = nil + if nb.Valid { + value = nb.Bool + } } else { - value = field.Float() + value = field.Bool() + } + case TypeCharField, TypeTextField: + if ns, ok := field.Interface().(sql.NullString); ok { + value = nil + if ns.Valid { + value = ns.String + } + } else { + value = field.String() + } + case TypeFloatField, TypeDecimalField: + if nf, ok := field.Interface().(sql.NullFloat64); ok { + value = nil + if nf.Valid { + value = nf.Float64 + } + } else { + vu := field.Interface() + if _, ok := vu.(float32); ok { + value, _ = StrTo(ToStr(vu)).Float64() + } else { + value = field.Float() + } } case TypeDateField, TypeDateTimeField: value = field.Interface() @@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val case fi.fieldType&IsPostiveIntegerField > 0: value = field.Uint() case fi.fieldType&IsIntegerField > 0: - value = field.Int() + if ni, ok := field.Interface().(sql.NullInt64); ok { + value = nil + if ni.Valid { + value = ni.Int64 + } + } else { + value = field.Int() + } case fi.fieldType&IsRelField > 0: if field.IsNil() { value = nil @@ -144,6 +172,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val switch fi.fieldType { case TypeDateField, TypeDateTimeField: if fi.auto_now || fi.auto_now_add && insert { + if insert { + if t, ok := value.(time.Time); ok && !t.IsZero() { + break + } + } tnow := time.Now() d.ins.TimeToDB(&tnow, tz) value = tnow @@ -883,13 +916,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri } arg := params[0] - if operator == "in" { + switch operator { + case "in": marks := make([]string, len(params)) for i, _ := range marks { marks[i] = "?" } sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - } else { + case "between": + if len(params) != 2 { + panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) + } + sql = "BETWEEN ? AND ?" + default: if len(params) > 1 { panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) } @@ -1117,17 +1156,37 @@ setValue: switch { case fieldType == TypeBooleanField: if isNative { - if value == nil { - value = false + if nb, ok := field.Interface().(sql.NullBool); ok { + if value == nil { + nb.Valid = false + } else { + nb.Bool = value.(bool) + nb.Valid = true + } + field.Set(reflect.ValueOf(nb)) + } else { + if value == nil { + value = false + } + field.SetBool(value.(bool)) } - field.SetBool(value.(bool)) } case fieldType == TypeCharField || fieldType == TypeTextField: if isNative { - if value == nil { - value = "" + if ns, ok := field.Interface().(sql.NullString); ok { + if value == nil { + ns.Valid = false + } else { + ns.String = value.(string) + ns.Valid = true + } + field.Set(reflect.ValueOf(ns)) + } else { + if value == nil { + value = "" + } + field.SetString(value.(string)) } - field.SetString(value.(string)) } case fieldType == TypeDateField || fieldType == TypeDateTimeField: if isNative { @@ -1146,18 +1205,39 @@ setValue: } } else { if isNative { - if value == nil { - value = int64(0) + if ni, ok := field.Interface().(sql.NullInt64); ok { + if value == nil { + ni.Valid = false + } else { + ni.Int64 = value.(int64) + ni.Valid = true + } + field.Set(reflect.ValueOf(ni)) + } else { + if value == nil { + value = int64(0) + } + field.SetInt(value.(int64)) } - field.SetInt(value.(int64)) } } case fieldType == TypeFloatField || fieldType == TypeDecimalField: if isNative { - if value == nil { - value = float64(0) + if nf, ok := field.Interface().(sql.NullFloat64); ok { + if value == nil { + nf.Valid = false + } else { + nf.Float64 = value.(float64) + nf.Valid = true + } + field.Set(reflect.ValueOf(nf)) + } else { + + if value == nil { + value = float64(0) + } + field.SetFloat(value.(float64)) } - field.SetFloat(value.(float64)) } case fieldType&IsRelField > 0: if value != nil { diff --git a/orm/db_alias.go b/orm/db_alias.go index 22066514..6a6623cc 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -168,7 +168,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { } if dataBaseCache.add(aliasName, al) == false { - return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName) + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) } return al, nil @@ -239,7 +239,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz } else { - return fmt.Errorf("DataBase name `%s` not registered\n", aliasName) + return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName) } return nil } @@ -260,3 +260,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) } } + +// Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + var name string + if len(aliasNames) > 0 { + name = aliasNames[0] + } else { + name = "default" + } + if al, ok := dataBaseCache.get(name); ok { + return al.DB, nil + } else { + return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name) + } +} diff --git a/orm/models.go b/orm/models.go index 5744d865..59a8a8a1 100644 --- a/orm/models.go +++ b/orm/models.go @@ -98,3 +98,9 @@ func (mc *_modelCache) clean() { mc.cacheByFN = make(map[string]*modelInfo) mc.done = false } + +// Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + modelCache.clean() +} diff --git a/orm/models_test.go b/orm/models_test.go index 706f04dc..168c091a 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -1,6 +1,7 @@ package orm import ( + "database/sql" "encoding/json" "fmt" "os" @@ -116,27 +117,31 @@ type Data struct { } type DataNull struct { - Id int - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)""` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` + Id int + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)""` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` } // only for mysql @@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm #### Sqlite3 -touch /path/to/orm_test.db export ORM_DRIVER=sqlite3 -export ORM_SOURCE=/path/to/orm_test.db +export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/astaxie/beego/orm diff --git a/orm/models_utils.go b/orm/models_utils.go index 1466a724..759093ef 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -1,6 +1,7 @@ package orm import ( + "database/sql" "fmt" "reflect" "strings" @@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col // return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) - switch elm.Kind() { - case reflect.Int8: + switch elm.Interface().(type) { + case int8: ft = TypeBitField - case reflect.Int16: + case int16: ft = TypeSmallIntegerField - case reflect.Int32, reflect.Int: + case int32, int: ft = TypeIntegerField - case reflect.Int64: + case int64, sql.NullInt64: ft = TypeBigIntegerField - case reflect.Uint8: + case uint8: ft = TypePositiveBitField - case reflect.Uint16: + case uint16: ft = TypePositiveSmallIntegerField - case reflect.Uint32, reflect.Uint: + case uint32, uint: ft = TypePositiveIntegerField - case reflect.Uint64: + case uint64: ft = TypePositiveBigIntegerField - case reflect.Float32, reflect.Float64: + case float32, float64, sql.NullFloat64: ft = TypeFloatField - case reflect.Bool: + case bool, sql.NullBool: ft = TypeBooleanField - case reflect.String: + case string, sql.NullString: ft = TypeCharField - case reflect.Invalid: default: if elm.CanInterface() { if _, ok := elm.Interface().(time.Time); ok { diff --git a/orm/orm_test.go b/orm/orm_test.go index c951d5ca..69f2fc86 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -2,6 +2,7 @@ package orm import ( "bytes" + "database/sql" "fmt" "io/ioutil" "os" @@ -138,6 +139,15 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) { } } +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + func TestSyncDb(t *testing.T) { RegisterModel(new(Data), new(DataNull)) RegisterModel(new(User)) @@ -258,12 +268,45 @@ func TestNullDataTypes(t *testing.T) { err = dORM.Read(&d) throwFail(t, err) + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() throwFail(t, err) d = DataNull{Id: 2} err = dORM.Read(&d) throwFail(t, err) + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{"test", true}, + NullBool: sql.NullBool{true, true}, + NullInt64: sql.NullInt64{42, true}, + NullFloat64: sql.NullFloat64{42.42, true}, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{Id: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) } func TestCRUD(t *testing.T) { @@ -619,6 +662,14 @@ func TestOperators(t *testing.T) { num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) } func TestSetCond(t *testing.T) { @@ -1577,7 +1628,6 @@ func TestDelete(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 4)) - fmt.Println("...") qs = dORM.QueryTable("comment") num, err = qs.Filter("Post__User", 3).Delete() throwFail(t, err) @@ -1646,10 +1696,10 @@ func TestTransaction(t *testing.T) { func TestReadOrCreate(t *testing.T) { u := &User{ UserName: "Kyle", - Email: "kylemcc@gmail.com", + Email: "kylemcc@gmail.com", Password: "other_pass", - Status: 7, - IsStaff: false, + Status: 7, + IsStaff: false, IsActive: true, } diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go index 33577d72..5838acc2 100644 --- a/plugins/auth/basic.go +++ b/plugins/auth/basic.go @@ -8,7 +8,7 @@ package auth // } // return false // } -// authPlugin := auth.NewBasicAuthenticator(SecretAuth) +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm") // beego.AddFilter("*","AfterStatic",authPlugin) import ( diff --git a/router.go b/router.go index bb083768..62f3a1c5 100644 --- a/router.go +++ b/router.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "os" + "path" "reflect" "regexp" "runtime" @@ -545,14 +546,26 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) //static file server for prefix, staticDir := range StaticDir { + if len(prefix) == 0 { + continue + } if r.URL.Path == "/favicon.ico" { - file := staticDir + r.URL.Path - http.ServeFile(w, r, file) - w.started = true - goto Admin + file := path.Join(staticDir, r.URL.Path) + if utils.FileExists(file) { + http.ServeFile(w, r, file) + w.started = true + goto Admin + } } if strings.HasPrefix(r.URL.Path, prefix) { - file := staticDir + r.URL.Path[len(prefix):] + if len(r.URL.Path) > len(prefix) && r.URL.Path[len(prefix)] != '/' { + continue + } + if r.URL.Path == prefix && prefix[len(prefix)-1] != '/' { + http.Redirect(rw, r, r.URL.Path+"/", 302) + goto Admin + } + file := path.Join(staticDir, r.URL.Path[len(prefix):]) finfo, err := os.Stat(file) if err != nil { if RunMode == "dev" { diff --git a/session/sess_couchbase.go b/session/sess_couchbase.go new file mode 100644 index 00000000..74b2242c --- /dev/null +++ b/session/sess_couchbase.go @@ -0,0 +1,203 @@ +package session + +import ( + "github.com/couchbaselabs/go-couchbase" + "net/http" + "strings" + "sync" +) + +var couchbpder = &CouchbaseProvider{} + +type CouchbaseSessionStore struct { + b *couchbase.Bucket + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +type CouchbaseProvider struct { + maxlifetime int64 + savePath string + pool string + bucket string + b *couchbase.Bucket +} + +func (cs *CouchbaseSessionStore) Set(key, value interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values[key] = value + return nil +} + +func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} { + cs.lock.RLock() + defer cs.lock.RUnlock() + if v, ok := cs.values[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (cs *CouchbaseSessionStore) Delete(key interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + delete(cs.values, key) + return nil +} + +func (cs *CouchbaseSessionStore) Flush() error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values = make(map[interface{}]interface{}) + return nil +} + +func (cs *CouchbaseSessionStore) SessionID() string { + return cs.sid +} + +func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) { + defer cs.b.Close() + + // if rs.values is empty, return directly + if len(cs.values) < 1 { + cs.b.Delete(cs.sid) + return + } + + bo, err := encodeGob(cs.values) + if err != nil { + return + } + + cs.b.Set(cs.sid, int(cs.maxlifetime), bo) +} + +func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket { + c, err := couchbase.Connect(cp.savePath) + if err != nil { + return nil + } + + pool, err := c.GetPool(cp.pool) + if err != nil { + return nil + } + + bucket, err := pool.GetBucket(cp.bucket) + if err != nil { + return nil + } + + return bucket +} + +// init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error { + cp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + cp.savePath = configs[0] + } + if len(configs) > 1 { + cp.pool = configs[1] + } + if len(configs) > 2 { + cp.bucket = configs[2] + } + + return nil +} + +// read couchbase session by sid +func (cp *CouchbaseProvider) SessionRead(sid string) (SessionStore, error) { + cp.b = cp.getBucket() + + var doc []byte + + err := cp.b.Get(sid, &doc) + var kv map[interface{}]interface{} + if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +func (cp *CouchbaseProvider) SessionExist(sid string) bool { + cp.b = cp.getBucket() + defer cp.b.Close() + + var doc []byte + + if err := cp.b.Get(sid, &doc); err != nil || doc == nil { + return false + } else { + return true + } +} + +func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + cp.b = cp.getBucket() + + var doc []byte + if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { + cp.b.Set(sid, int(cp.maxlifetime), "") + } else { + err := cp.b.Delete(oldsid) + if err != nil { + return nil, err + } + _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) + } + + err := cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +func (cp *CouchbaseProvider) SessionDestroy(sid string) error { + cp.b = cp.getBucket() + defer cp.b.Close() + + cp.b.Delete(sid) + return nil +} + +func (cp *CouchbaseProvider) SessionGC() { + return +} + +func (cp *CouchbaseProvider) SessionAll() int { + return 0 +} + +func init() { + Register("couchbase", couchbpder) +} diff --git a/session/sess_file.go b/session/sess_file.go index e8746532..7e9e2229 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -152,8 +152,7 @@ func (fp *FileProvider) SessionExist(sid string) bool { func (fp *FileProvider) SessionDestroy(sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() - - os.Remove(path.Join(fp.savePath)) + os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) return nil } diff --git a/session/sess_redis.go b/session/sess_redis.go index 3c51b793..e64d4c90 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -129,7 +129,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { } return c, err }, rp.poolsize) - return nil + + return rp.poollist.Get().Err() } // read redis session by sid diff --git a/session/session.go b/session/session.go index d1a44538..bc8832b0 100644 --- a/session/session.go +++ b/session/session.go @@ -56,11 +56,10 @@ type managerConfig struct { EnableSetCookie bool `json:"enableSetCookie,omitempty"` Gclifetime int64 `json:"gclifetime"` Maxlifetime int64 `json:"maxLifetime"` - Maxage int `json:"maxage"` Secure bool `json:"secure"` SessionIDHashFunc string `json:"sessionIDHashFunc"` SessionIDHashKey string `json:"sessionIDHashKey"` - CookieLifeTime int64 `json:"cookieLifeTime"` + CookieLifeTime int `json:"cookieLifeTime"` ProviderConfig string `json:"providerConfig"` } @@ -125,8 +124,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se Path: "/", HttpOnly: true, Secure: manager.config.Secure} - if manager.config.Maxage >= 0 { - cookie.MaxAge = manager.config.Maxage + if manager.config.CookieLifeTime >= 0 { + cookie.MaxAge = manager.config.CookieLifeTime } if manager.config.EnableSetCookie { http.SetCookie(w, cookie) @@ -144,8 +143,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se Path: "/", HttpOnly: true, Secure: manager.config.Secure} - if manager.config.Maxage >= 0 { - cookie.MaxAge = manager.config.Maxage + if manager.config.CookieLifeTime >= 0 { + cookie.MaxAge = manager.config.CookieLifeTime } if manager.config.EnableSetCookie { http.SetCookie(w, cookie) @@ -206,8 +205,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque cookie.HttpOnly = true cookie.Path = "/" } - if manager.config.Maxage >= 0 { - cookie.MaxAge = manager.config.Maxage + if manager.config.CookieLifeTime >= 0 { + cookie.MaxAge = manager.config.CookieLifeTime } http.SetCookie(w, cookie) r.AddCookie(cookie) diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go index 14979d78..f3998733 100644 --- a/utils/captcha/captcha.go +++ b/utils/captcha/captcha.go @@ -67,7 +67,7 @@ const ( fieldIdName = "captcha_id" fieldCaptchaName = "captcha" cachePrefix = "captcha_" - urlPrefix = "/captcha/" + defaultURLPrefix = "/captcha/" ) // Captcha struct @@ -76,7 +76,7 @@ type Captcha struct { store cache.Cache // url prefix for captcha image - urlPrefix string + URLPrefix string // specify captcha id input field name FieldIdName string @@ -155,7 +155,7 @@ func (c *Captcha) CreateCaptchaHtml() template.HTML { return template.HTML(fmt.Sprintf(``+ ``+ ``+ - ``, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value)) + ``, c.FieldIdName, value, c.URLPrefix, value, c.URLPrefix, value)) } // create a new captcha id @@ -224,14 +224,14 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { cpt.StdHeight = stdHeight if len(urlPrefix) == 0 { - urlPrefix = urlPrefix + urlPrefix = defaultURLPrefix } if urlPrefix[len(urlPrefix)-1] != '/' { urlPrefix += "/" } - cpt.urlPrefix = urlPrefix + cpt.URLPrefix = urlPrefix return cpt } @@ -242,7 +242,7 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { cpt := NewCaptcha(urlPrefix, store) // create filter for serve captcha image - beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler) + beego.AddFilter(cpt.URLPrefix+":", "BeforeRouter", cpt.Handler) // add to template func map beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml) diff --git a/validation/validators.go b/validation/validators.go index 6abc7b12..d198d442 100644 --- a/validation/validators.go +++ b/validation/validators.go @@ -443,7 +443,7 @@ func (b Base64) GetLimitValue() interface{} { } // just for chinese mobile phone number -var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][01236789]))\\d{8}$") +var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][012356789]))\\d{8}$") type Mobile struct { Match