diff --git a/beego.go b/beego.go
index 29cd45a0..e658c3b5 100644
--- a/beego.go
+++ b/beego.go
@@ -12,7 +12,7 @@ import (
)
// beego web framework version.
-const VERSION = "1.1.0"
+const VERSION = "1.1.1"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
@@ -28,12 +28,12 @@ type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
- return make([]groupRouter, 0)
+ return make(GroupRouters, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
-func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
+func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
@@ -48,16 +48,16 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM
"",
}
}
- gr = append(gr, newRG)
+ *gr = append(*gr, newRG)
}
-func (gr GroupRouters) AddAuto(c ControllerInterface) {
+func (gr *GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
- gr = append(gr, newRG)
+ *gr = append(*gr, newRG)
}
// AddGroupRouter with the prefix
diff --git a/cache/file.go b/cache/file.go
index 410da3a0..d750fecc 100644
--- a/cache/file.go
+++ b/cache/file.go
@@ -147,6 +147,8 @@ func (this *FileCache) Get(key string) interface{} {
// timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
+ gob.Register(val)
+
filename := this.getCacheFileName(key)
var item FileCacheItem
item.Data = val
diff --git a/context/context.go b/context/context.go
index ee308476..36d566b5 100644
--- a/context/context.go
+++ b/context/context.go
@@ -1,7 +1,14 @@
package context
import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "encoding/base64"
+ "fmt"
"net/http"
+ "strconv"
+ "strings"
+ "time"
"github.com/astaxie/beego/middleware"
)
@@ -59,3 +66,41 @@ func (ctx *Context) GetCookie(key string) string {
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...)
}
+
+// Get secure cookie from request by a given key.
+func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
+ val := ctx.Input.Cookie(key)
+ if val == "" {
+ return "", false
+ }
+
+ parts := strings.SplitN(val, "|", 3)
+
+ if len(parts) != 3 {
+ return "", false
+ }
+
+ vs := parts[0]
+ timestamp := parts[1]
+ sig := parts[2]
+
+ h := hmac.New(sha1.New, []byte(Secret))
+ fmt.Fprintf(h, "%s%s", vs, timestamp)
+
+ if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
+ return "", false
+ }
+ res, _ := base64.URLEncoding.DecodeString(vs)
+ return string(res), true
+}
+
+// Set Secure cookie for response.
+func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
+ vs := base64.URLEncoding.EncodeToString([]byte(value))
+ timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
+ h := hmac.New(sha1.New, []byte(Secret))
+ fmt.Fprintf(h, "%s%s", vs, timestamp)
+ sig := fmt.Sprintf("%02x", h.Sum(nil))
+ cookie := strings.Join([]string{vs, timestamp, sig}, "|")
+ ctx.Output.Cookie(name, cookie, others...)
+}
diff --git a/context/output.go b/context/output.go
index 43948e15..a8a304b6 100644
--- a/context/output.go
+++ b/context/output.go
@@ -77,39 +77,77 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
if len(others) > 0 {
- switch others[0].(type) {
+ switch v := others[0].(type) {
case int:
- if others[0].(int) > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int))
- } else if others[0].(int) < 0 {
+ if v > 0 {
+ fmt.Fprintf(&b, "; Max-Age=%d", v)
+ } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
}
case int64:
- if others[0].(int64) > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64))
- } else if others[0].(int64) < 0 {
+ if v > 0 {
+ fmt.Fprintf(&b, "; Max-Age=%d", v)
+ } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
}
case int32:
- if others[0].(int32) > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32))
- } else if others[0].(int32) < 0 {
+ if v > 0 {
+ fmt.Fprintf(&b, "; Max-Age=%d", v)
+ } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
}
}
}
+
+ // the settings below
+ // Path, Domain, Secure, HttpOnly
+ // can use nil skip set
+
+ // default "/"
if len(others) > 1 {
- fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string)))
+ if v, ok := others[1].(string); ok && len(v) > 0 {
+ fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
+ }
+ } else {
+ fmt.Fprintf(&b, "; Path=%s", "/")
}
+
+ // default empty
if len(others) > 2 {
- fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string)))
+ if v, ok := others[2].(string); ok && len(v) > 0 {
+ fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
+ }
}
+
+ // default empty
if len(others) > 3 {
- fmt.Fprintf(&b, "; Secure")
+ var secure bool
+ switch v := others[3].(type) {
+ case bool:
+ secure = v
+ default:
+ if others[3] != nil {
+ secure = true
+ }
+ }
+ if secure {
+ fmt.Fprintf(&b, "; Secure")
+ }
}
+
+ // default true
+ httponly := true
if len(others) > 4 {
+ if v, ok := others[4].(bool); ok && !v || others[4] == nil {
+ // HttpOnly = false
+ httponly = false
+ }
+ }
+
+ if httponly {
fmt.Fprintf(&b, "; HttpOnly")
}
+
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
}
diff --git a/controller.go b/controller.go
index 9d783747..c9ad10e9 100644
--- a/controller.go
+++ b/controller.go
@@ -2,11 +2,7 @@ package beego
import (
"bytes"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
"errors"
- "fmt"
"html/template"
"io"
"io/ioutil"
@@ -17,7 +13,6 @@ import (
"reflect"
"strconv"
"strings"
- "time"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/session"
@@ -313,11 +308,11 @@ func (c *Controller) GetString(key string) string {
// GetStrings returns the input string slice by key string.
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
func (c *Controller) GetStrings(key string) []string {
- r := c.Ctx.Request
- if r.Form == nil {
+ f := c.Input()
+ if f == nil {
return []string{}
}
- vs := r.Form[key]
+ vs := f[key]
if len(vs) > 0 {
return vs
}
@@ -417,40 +412,12 @@ func (c *Controller) IsAjax() bool {
// GetSecureCookie returns decoded cookie value from encoded browser cookie values.
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
- val := c.Ctx.GetCookie(key)
- if val == "" {
- return "", false
- }
-
- parts := strings.SplitN(val, "|", 3)
-
- if len(parts) != 3 {
- return "", false
- }
-
- vs := parts[0]
- timestamp := parts[1]
- sig := parts[2]
-
- h := hmac.New(sha1.New, []byte(Secret))
- fmt.Fprintf(h, "%s%s", vs, timestamp)
-
- if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
- return "", false
- }
- res, _ := base64.URLEncoding.DecodeString(vs)
- return string(res), true
+ return c.Ctx.GetSecureCookie(Secret, key)
}
// SetSecureCookie puts value into cookie after encoded the value.
-func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) {
- vs := base64.URLEncoding.EncodeToString([]byte(val))
- timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
- h := hmac.New(sha1.New, []byte(Secret))
- fmt.Fprintf(h, "%s%s", vs, timestamp)
- sig := fmt.Sprintf("%02x", h.Sum(nil))
- cookie := strings.Join([]string{vs, timestamp, sig}, "|")
- c.Ctx.SetCookie(name, cookie, age, "/")
+func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
+ c.Ctx.SetSecureCookie(Secret, name, value, others...)
}
// XsrfToken creates a xsrf token string and returns.
diff --git a/httplib/httplib.go b/httplib/httplib.go
index 1462ee45..d313c603 100644
--- a/httplib/httplib.go
+++ b/httplib/httplib.go
@@ -24,7 +24,7 @@ func Get(url string) *BeegoHttpRequest {
req.Method = "GET"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
- return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
+ return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
}
// Post returns *BeegoHttpRequest with POST method.
@@ -33,7 +33,7 @@ func Post(url string) *BeegoHttpRequest {
req.Method = "POST"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
- return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
+ return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
}
// Put returns *BeegoHttpRequest with PUT method.
@@ -42,7 +42,7 @@ func Put(url string) *BeegoHttpRequest {
req.Method = "PUT"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
- return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
+ return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
}
// Delete returns *BeegoHttpRequest DELETE GET method.
@@ -51,7 +51,7 @@ func Delete(url string) *BeegoHttpRequest {
req.Method = "DELETE"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
- return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
+ return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
}
// Head returns *BeegoHttpRequest with HEAD method.
@@ -60,7 +60,7 @@ func Head(url string) *BeegoHttpRequest {
req.Method = "HEAD"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
- return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
+ return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
}
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
@@ -72,6 +72,8 @@ type BeegoHttpRequest struct {
connectTimeout time.Duration
readWriteTimeout time.Duration
tlsClientConfig *tls.Config
+ proxy func(*http.Request) (*url.URL, error)
+ transport http.RoundTripper
}
// Debug sets show debug or not when executing request.
@@ -105,6 +107,24 @@ func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
return b
}
+// Set transport to
+func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
+ b.transport = transport
+ return b
+}
+
+// Set http proxy
+// example:
+//
+// func(req *http.Request) (*url.URL, error) {
+// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
+// return u, nil
+// }
+func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
+ b.proxy = proxy
+ return b
+}
+
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
@@ -171,12 +191,34 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
println(string(dump))
}
- client := &http.Client{
- Transport: &http.Transport{
+ trans := b.transport
+
+ if trans == nil {
+ // create default transport
+ trans = &http.Transport{
TLSClientConfig: b.tlsClientConfig,
+ Proxy: b.proxy,
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
- },
+ }
+ } else {
+ // if b.transport is *http.Transport then set the settings.
+ if t, ok := trans.(*http.Transport); ok {
+ if t.TLSClientConfig == nil {
+ t.TLSClientConfig = b.tlsClientConfig
+ }
+ if t.Proxy == nil {
+ t.Proxy = b.proxy
+ }
+ if t.Dial == nil {
+ t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout)
+ }
+ }
}
+
+ client := &http.Client{
+ Transport: trans,
+ }
+
resp, err := client.Do(b.req)
if err != nil {
return nil, err
diff --git a/orm/db.go b/orm/db.go
index 60e53765..dfb53621 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -35,7 +35,7 @@ var (
"istartswith": true,
"iendswith": true,
"in": true,
- // "range": true,
+ "between": true,
// "year": true,
// "month": true,
// "day": true,
@@ -103,15 +103,36 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else {
switch fi.fieldType {
case TypeBooleanField:
- value = field.Bool()
- case TypeCharField, TypeTextField:
- value = field.String()
- case TypeFloatField, TypeDecimalField:
- vu := field.Interface()
- if _, ok := vu.(float32); ok {
- value, _ = StrTo(ToStr(vu)).Float64()
+ if nb, ok := field.Interface().(sql.NullBool); ok {
+ value = nil
+ if nb.Valid {
+ value = nb.Bool
+ }
} else {
- value = field.Float()
+ value = field.Bool()
+ }
+ case TypeCharField, TypeTextField:
+ if ns, ok := field.Interface().(sql.NullString); ok {
+ value = nil
+ if ns.Valid {
+ value = ns.String
+ }
+ } else {
+ value = field.String()
+ }
+ case TypeFloatField, TypeDecimalField:
+ if nf, ok := field.Interface().(sql.NullFloat64); ok {
+ value = nil
+ if nf.Valid {
+ value = nf.Float64
+ }
+ } else {
+ vu := field.Interface()
+ if _, ok := vu.(float32); ok {
+ value, _ = StrTo(ToStr(vu)).Float64()
+ } else {
+ value = field.Float()
+ }
}
case TypeDateField, TypeDateTimeField:
value = field.Interface()
@@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint()
case fi.fieldType&IsIntegerField > 0:
- value = field.Int()
+ if ni, ok := field.Interface().(sql.NullInt64); ok {
+ value = nil
+ if ni.Valid {
+ value = ni.Int64
+ }
+ } else {
+ value = field.Int()
+ }
case fi.fieldType&IsRelField > 0:
if field.IsNil() {
value = nil
@@ -144,6 +172,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
+ if insert {
+ if t, ok := value.(time.Time); ok && !t.IsZero() {
+ break
+ }
+ }
tnow := time.Now()
d.ins.TimeToDB(&tnow, tz)
value = tnow
@@ -883,13 +916,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
}
arg := params[0]
- if operator == "in" {
+ switch operator {
+ case "in":
marks := make([]string, len(params))
for i, _ := range marks {
marks[i] = "?"
}
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
- } else {
+ case "between":
+ if len(params) != 2 {
+ panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params)))
+ }
+ sql = "BETWEEN ? AND ?"
+ default:
if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
}
@@ -1117,17 +1156,37 @@ setValue:
switch {
case fieldType == TypeBooleanField:
if isNative {
- if value == nil {
- value = false
+ if nb, ok := field.Interface().(sql.NullBool); ok {
+ if value == nil {
+ nb.Valid = false
+ } else {
+ nb.Bool = value.(bool)
+ nb.Valid = true
+ }
+ field.Set(reflect.ValueOf(nb))
+ } else {
+ if value == nil {
+ value = false
+ }
+ field.SetBool(value.(bool))
}
- field.SetBool(value.(bool))
}
case fieldType == TypeCharField || fieldType == TypeTextField:
if isNative {
- if value == nil {
- value = ""
+ if ns, ok := field.Interface().(sql.NullString); ok {
+ if value == nil {
+ ns.Valid = false
+ } else {
+ ns.String = value.(string)
+ ns.Valid = true
+ }
+ field.Set(reflect.ValueOf(ns))
+ } else {
+ if value == nil {
+ value = ""
+ }
+ field.SetString(value.(string))
}
- field.SetString(value.(string))
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative {
@@ -1146,18 +1205,39 @@ setValue:
}
} else {
if isNative {
- if value == nil {
- value = int64(0)
+ if ni, ok := field.Interface().(sql.NullInt64); ok {
+ if value == nil {
+ ni.Valid = false
+ } else {
+ ni.Int64 = value.(int64)
+ ni.Valid = true
+ }
+ field.Set(reflect.ValueOf(ni))
+ } else {
+ if value == nil {
+ value = int64(0)
+ }
+ field.SetInt(value.(int64))
}
- field.SetInt(value.(int64))
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if isNative {
- if value == nil {
- value = float64(0)
+ if nf, ok := field.Interface().(sql.NullFloat64); ok {
+ if value == nil {
+ nf.Valid = false
+ } else {
+ nf.Float64 = value.(float64)
+ nf.Valid = true
+ }
+ field.Set(reflect.ValueOf(nf))
+ } else {
+
+ if value == nil {
+ value = float64(0)
+ }
+ field.SetFloat(value.(float64))
}
- field.SetFloat(value.(float64))
}
case fieldType&IsRelField > 0:
if value != nil {
diff --git a/orm/db_alias.go b/orm/db_alias.go
index 22066514..6a6623cc 100644
--- a/orm/db_alias.go
+++ b/orm/db_alias.go
@@ -168,7 +168,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
}
if dataBaseCache.add(aliasName, al) == false {
- return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
+ return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
@@ -239,7 +239,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
- return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
+ return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
}
return nil
}
@@ -260,3 +260,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
}
}
+
+// Get *sql.DB from registered database by db alias name.
+// Use "default" as alias name if you not set.
+func GetDB(aliasNames ...string) (*sql.DB, error) {
+ var name string
+ if len(aliasNames) > 0 {
+ name = aliasNames[0]
+ } else {
+ name = "default"
+ }
+ if al, ok := dataBaseCache.get(name); ok {
+ return al.DB, nil
+ } else {
+ return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
+ }
+}
diff --git a/orm/models.go b/orm/models.go
index 5744d865..59a8a8a1 100644
--- a/orm/models.go
+++ b/orm/models.go
@@ -98,3 +98,9 @@ func (mc *_modelCache) clean() {
mc.cacheByFN = make(map[string]*modelInfo)
mc.done = false
}
+
+// Clean model cache. Then you can re-RegisterModel.
+// Common use this api for test case.
+func ResetModelCache() {
+ modelCache.clean()
+}
diff --git a/orm/models_test.go b/orm/models_test.go
index 706f04dc..168c091a 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -1,6 +1,7 @@
package orm
import (
+ "database/sql"
"encoding/json"
"fmt"
"os"
@@ -116,27 +117,31 @@ type Data struct {
}
type DataNull struct {
- Id int
- Boolean bool `orm:"null"`
- Char string `orm:"null;size(50)"`
- Text string `orm:"null;type(text)"`
- Date time.Time `orm:"null;type(date)"`
- DateTime time.Time `orm:"null;column(datetime)""`
- Byte byte `orm:"null"`
- Rune rune `orm:"null"`
- Int int `orm:"null"`
- Int8 int8 `orm:"null"`
- Int16 int16 `orm:"null"`
- Int32 int32 `orm:"null"`
- Int64 int64 `orm:"null"`
- Uint uint `orm:"null"`
- Uint8 uint8 `orm:"null"`
- Uint16 uint16 `orm:"null"`
- Uint32 uint32 `orm:"null"`
- Uint64 uint64 `orm:"null"`
- Float32 float32 `orm:"null"`
- Float64 float64 `orm:"null"`
- Decimal float64 `orm:"digits(8);decimals(4);null"`
+ Id int
+ Boolean bool `orm:"null"`
+ Char string `orm:"null;size(50)"`
+ Text string `orm:"null;type(text)"`
+ Date time.Time `orm:"null;type(date)"`
+ DateTime time.Time `orm:"null;column(datetime)""`
+ Byte byte `orm:"null"`
+ Rune rune `orm:"null"`
+ Int int `orm:"null"`
+ Int8 int8 `orm:"null"`
+ Int16 int16 `orm:"null"`
+ Int32 int32 `orm:"null"`
+ Int64 int64 `orm:"null"`
+ Uint uint `orm:"null"`
+ Uint8 uint8 `orm:"null"`
+ Uint16 uint16 `orm:"null"`
+ Uint32 uint32 `orm:"null"`
+ Uint64 uint64 `orm:"null"`
+ Float32 float32 `orm:"null"`
+ Float64 float64 `orm:"null"`
+ Decimal float64 `orm:"digits(8);decimals(4);null"`
+ NullString sql.NullString `orm:"null"`
+ NullBool sql.NullBool `orm:"null"`
+ NullFloat64 sql.NullFloat64 `orm:"null"`
+ NullInt64 sql.NullInt64 `orm:"null"`
}
// only for mysql
@@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm
#### Sqlite3
-touch /path/to/orm_test.db
export ORM_DRIVER=sqlite3
-export ORM_SOURCE=/path/to/orm_test.db
+export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm
diff --git a/orm/models_utils.go b/orm/models_utils.go
index 1466a724..759093ef 100644
--- a/orm/models_utils.go
+++ b/orm/models_utils.go
@@ -1,6 +1,7 @@
package orm
import (
+ "database/sql"
"fmt"
"reflect"
"strings"
@@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
- switch elm.Kind() {
- case reflect.Int8:
+ switch elm.Interface().(type) {
+ case int8:
ft = TypeBitField
- case reflect.Int16:
+ case int16:
ft = TypeSmallIntegerField
- case reflect.Int32, reflect.Int:
+ case int32, int:
ft = TypeIntegerField
- case reflect.Int64:
+ case int64, sql.NullInt64:
ft = TypeBigIntegerField
- case reflect.Uint8:
+ case uint8:
ft = TypePositiveBitField
- case reflect.Uint16:
+ case uint16:
ft = TypePositiveSmallIntegerField
- case reflect.Uint32, reflect.Uint:
+ case uint32, uint:
ft = TypePositiveIntegerField
- case reflect.Uint64:
+ case uint64:
ft = TypePositiveBigIntegerField
- case reflect.Float32, reflect.Float64:
+ case float32, float64, sql.NullFloat64:
ft = TypeFloatField
- case reflect.Bool:
+ case bool, sql.NullBool:
ft = TypeBooleanField
- case reflect.String:
+ case string, sql.NullString:
ft = TypeCharField
- case reflect.Invalid:
default:
if elm.CanInterface() {
if _, ok := elm.Interface().(time.Time); ok {
diff --git a/orm/orm_test.go b/orm/orm_test.go
index c951d5ca..69f2fc86 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -2,6 +2,7 @@ package orm
import (
"bytes"
+ "database/sql"
"fmt"
"io/ioutil"
"os"
@@ -138,6 +139,15 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
}
}
+func TestGetDB(t *testing.T) {
+ if db, err := GetDB(); err != nil {
+ throwFailNow(t, err)
+ } else {
+ err = db.Ping()
+ throwFailNow(t, err)
+ }
+}
+
func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User))
@@ -258,12 +268,45 @@ func TestNullDataTypes(t *testing.T) {
err = dORM.Read(&d)
throwFail(t, err)
+ throwFail(t, AssertIs(d.NullBool.Valid, false))
+ throwFail(t, AssertIs(d.NullString.Valid, false))
+ throwFail(t, AssertIs(d.NullInt64.Valid, false))
+ throwFail(t, AssertIs(d.NullFloat64.Valid, false))
+
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
d = DataNull{Id: 2}
err = dORM.Read(&d)
throwFail(t, err)
+
+ d = DataNull{
+ DateTime: time.Now(),
+ NullString: sql.NullString{"test", true},
+ NullBool: sql.NullBool{true, true},
+ NullInt64: sql.NullInt64{42, true},
+ NullFloat64: sql.NullFloat64{42.42, true},
+ }
+
+ id, err = dORM.Insert(&d)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id, 3))
+
+ d = DataNull{Id: 3}
+ err = dORM.Read(&d)
+ throwFail(t, err)
+
+ throwFail(t, AssertIs(d.NullBool.Valid, true))
+ throwFail(t, AssertIs(d.NullBool.Bool, true))
+
+ throwFail(t, AssertIs(d.NullString.Valid, true))
+ throwFail(t, AssertIs(d.NullString.String, "test"))
+
+ throwFail(t, AssertIs(d.NullInt64.Valid, true))
+ throwFail(t, AssertIs(d.NullInt64.Int64, 42))
+
+ throwFail(t, AssertIs(d.NullFloat64.Valid, true))
+ throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
}
func TestCRUD(t *testing.T) {
@@ -619,6 +662,14 @@ func TestOperators(t *testing.T) {
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
+
+ num, err = qs.Filter("id__between", 2, 3).Count()
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 2))
+
+ num, err = qs.Filter("id__between", []int{2, 3}).Count()
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 2))
}
func TestSetCond(t *testing.T) {
@@ -1577,7 +1628,6 @@ func TestDelete(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
- fmt.Println("...")
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
@@ -1646,10 +1696,10 @@ func TestTransaction(t *testing.T) {
func TestReadOrCreate(t *testing.T) {
u := &User{
UserName: "Kyle",
- Email: "kylemcc@gmail.com",
+ Email: "kylemcc@gmail.com",
Password: "other_pass",
- Status: 7,
- IsStaff: false,
+ Status: 7,
+ IsStaff: false,
IsActive: true,
}
diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go
index 33577d72..5838acc2 100644
--- a/plugins/auth/basic.go
+++ b/plugins/auth/basic.go
@@ -8,7 +8,7 @@ package auth
// }
// return false
// }
-// authPlugin := auth.NewBasicAuthenticator(SecretAuth)
+// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm")
// beego.AddFilter("*","AfterStatic",authPlugin)
import (
diff --git a/router.go b/router.go
index bb083768..62f3a1c5 100644
--- a/router.go
+++ b/router.go
@@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"os"
+ "path"
"reflect"
"regexp"
"runtime"
@@ -545,14 +546,26 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
//static file server
for prefix, staticDir := range StaticDir {
+ if len(prefix) == 0 {
+ continue
+ }
if r.URL.Path == "/favicon.ico" {
- file := staticDir + r.URL.Path
- http.ServeFile(w, r, file)
- w.started = true
- goto Admin
+ file := path.Join(staticDir, r.URL.Path)
+ if utils.FileExists(file) {
+ http.ServeFile(w, r, file)
+ w.started = true
+ goto Admin
+ }
}
if strings.HasPrefix(r.URL.Path, prefix) {
- file := staticDir + r.URL.Path[len(prefix):]
+ if len(r.URL.Path) > len(prefix) && r.URL.Path[len(prefix)] != '/' {
+ continue
+ }
+ if r.URL.Path == prefix && prefix[len(prefix)-1] != '/' {
+ http.Redirect(rw, r, r.URL.Path+"/", 302)
+ goto Admin
+ }
+ file := path.Join(staticDir, r.URL.Path[len(prefix):])
finfo, err := os.Stat(file)
if err != nil {
if RunMode == "dev" {
diff --git a/session/sess_couchbase.go b/session/sess_couchbase.go
new file mode 100644
index 00000000..74b2242c
--- /dev/null
+++ b/session/sess_couchbase.go
@@ -0,0 +1,203 @@
+package session
+
+import (
+ "github.com/couchbaselabs/go-couchbase"
+ "net/http"
+ "strings"
+ "sync"
+)
+
+var couchbpder = &CouchbaseProvider{}
+
+type CouchbaseSessionStore struct {
+ b *couchbase.Bucket
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+type CouchbaseProvider struct {
+ maxlifetime int64
+ savePath string
+ pool string
+ bucket string
+ b *couchbase.Bucket
+}
+
+func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ cs.values[key] = value
+ return nil
+}
+
+func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
+ cs.lock.RLock()
+ defer cs.lock.RUnlock()
+ if v, ok := cs.values[key]; ok {
+ return v
+ } else {
+ return nil
+ }
+ return nil
+}
+
+func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ delete(cs.values, key)
+ return nil
+}
+
+func (cs *CouchbaseSessionStore) Flush() error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ cs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+func (cs *CouchbaseSessionStore) SessionID() string {
+ return cs.sid
+}
+
+func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
+ defer cs.b.Close()
+
+ // if rs.values is empty, return directly
+ if len(cs.values) < 1 {
+ cs.b.Delete(cs.sid)
+ return
+ }
+
+ bo, err := encodeGob(cs.values)
+ if err != nil {
+ return
+ }
+
+ cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
+}
+
+func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
+ c, err := couchbase.Connect(cp.savePath)
+ if err != nil {
+ return nil
+ }
+
+ pool, err := c.GetPool(cp.pool)
+ if err != nil {
+ return nil
+ }
+
+ bucket, err := pool.GetBucket(cp.bucket)
+ if err != nil {
+ return nil
+ }
+
+ return bucket
+}
+
+// init couchbase session
+// savepath like couchbase server REST/JSON URL
+// e.g. http://host:port/, Pool, Bucket
+func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
+ cp.maxlifetime = maxlifetime
+ configs := strings.Split(savePath, ",")
+ if len(configs) > 0 {
+ cp.savePath = configs[0]
+ }
+ if len(configs) > 1 {
+ cp.pool = configs[1]
+ }
+ if len(configs) > 2 {
+ cp.bucket = configs[2]
+ }
+
+ return nil
+}
+
+// read couchbase session by sid
+func (cp *CouchbaseProvider) SessionRead(sid string) (SessionStore, error) {
+ cp.b = cp.getBucket()
+
+ var doc []byte
+
+ err := cp.b.Get(sid, &doc)
+ var kv map[interface{}]interface{}
+ if doc == nil {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = decodeGob(doc)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ return cs, nil
+}
+
+func (cp *CouchbaseProvider) SessionExist(sid string) bool {
+ cp.b = cp.getBucket()
+ defer cp.b.Close()
+
+ var doc []byte
+
+ if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
+ return false
+ } else {
+ return true
+ }
+}
+
+func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+ cp.b = cp.getBucket()
+
+ var doc []byte
+ if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
+ cp.b.Set(sid, int(cp.maxlifetime), "")
+ } else {
+ err := cp.b.Delete(oldsid)
+ if err != nil {
+ return nil, err
+ }
+ _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
+ }
+
+ err := cp.b.Get(sid, &doc)
+ if err != nil {
+ return nil, err
+ }
+ var kv map[interface{}]interface{}
+ if doc == nil {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = decodeGob(doc)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ return cs, nil
+}
+
+func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
+ cp.b = cp.getBucket()
+ defer cp.b.Close()
+
+ cp.b.Delete(sid)
+ return nil
+}
+
+func (cp *CouchbaseProvider) SessionGC() {
+ return
+}
+
+func (cp *CouchbaseProvider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ Register("couchbase", couchbpder)
+}
diff --git a/session/sess_file.go b/session/sess_file.go
index e8746532..7e9e2229 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -152,8 +152,7 @@ func (fp *FileProvider) SessionExist(sid string) bool {
func (fp *FileProvider) SessionDestroy(sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
-
- os.Remove(path.Join(fp.savePath))
+ os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return nil
}
diff --git a/session/sess_redis.go b/session/sess_redis.go
index 3c51b793..e64d4c90 100644
--- a/session/sess_redis.go
+++ b/session/sess_redis.go
@@ -129,7 +129,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
}
return c, err
}, rp.poolsize)
- return nil
+
+ return rp.poollist.Get().Err()
}
// read redis session by sid
diff --git a/session/session.go b/session/session.go
index d1a44538..bc8832b0 100644
--- a/session/session.go
+++ b/session/session.go
@@ -56,11 +56,10 @@ type managerConfig struct {
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
- Maxage int `json:"maxage"`
Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"`
- CookieLifeTime int64 `json:"cookieLifeTime"`
+ CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
}
@@ -125,8 +124,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
Path: "/",
HttpOnly: true,
Secure: manager.config.Secure}
- if manager.config.Maxage >= 0 {
- cookie.MaxAge = manager.config.Maxage
+ if manager.config.CookieLifeTime >= 0 {
+ cookie.MaxAge = manager.config.CookieLifeTime
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
@@ -144,8 +143,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
Path: "/",
HttpOnly: true,
Secure: manager.config.Secure}
- if manager.config.Maxage >= 0 {
- cookie.MaxAge = manager.config.Maxage
+ if manager.config.CookieLifeTime >= 0 {
+ cookie.MaxAge = manager.config.CookieLifeTime
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
@@ -206,8 +205,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true
cookie.Path = "/"
}
- if manager.config.Maxage >= 0 {
- cookie.MaxAge = manager.config.Maxage
+ if manager.config.CookieLifeTime >= 0 {
+ cookie.MaxAge = manager.config.CookieLifeTime
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go
index 14979d78..f3998733 100644
--- a/utils/captcha/captcha.go
+++ b/utils/captcha/captcha.go
@@ -67,7 +67,7 @@ const (
fieldIdName = "captcha_id"
fieldCaptchaName = "captcha"
cachePrefix = "captcha_"
- urlPrefix = "/captcha/"
+ defaultURLPrefix = "/captcha/"
)
// Captcha struct
@@ -76,7 +76,7 @@ type Captcha struct {
store cache.Cache
// url prefix for captcha image
- urlPrefix string
+ URLPrefix string
// specify captcha id input field name
FieldIdName string
@@ -155,7 +155,7 @@ func (c *Captcha) CreateCaptchaHtml() template.HTML {
return template.HTML(fmt.Sprintf(``+
``+
``+
- ``, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
+ ``, c.FieldIdName, value, c.URLPrefix, value, c.URLPrefix, value))
}
// create a new captcha id
@@ -224,14 +224,14 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 {
- urlPrefix = urlPrefix
+ urlPrefix = defaultURLPrefix
}
if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/"
}
- cpt.urlPrefix = urlPrefix
+ cpt.URLPrefix = urlPrefix
return cpt
}
@@ -242,7 +242,7 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image
- beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
+ beego.AddFilter(cpt.URLPrefix+":", "BeforeRouter", cpt.Handler)
// add to template func map
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
diff --git a/validation/validators.go b/validation/validators.go
index 6abc7b12..d198d442 100644
--- a/validation/validators.go
+++ b/validation/validators.go
@@ -443,7 +443,7 @@ func (b Base64) GetLimitValue() interface{} {
}
// just for chinese mobile phone number
-var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][01236789]))\\d{8}$")
+var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][012356789]))\\d{8}$")
type Mobile struct {
Match