1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 15:30:55 +00:00

Merge branch 'release/1.1.1'

This commit is contained in:
asta.xie 2014-03-12 21:25:41 +08:00
commit 439b1afb85
20 changed files with 621 additions and 156 deletions

View File

@ -12,7 +12,7 @@ import (
) )
// beego web framework version. // beego web framework version.
const VERSION = "1.1.0" const VERSION = "1.1.1"
type hookfunc func() error //hook function to run type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc var hooks []hookfunc //hook function slice to store the hookfunc
@ -28,12 +28,12 @@ type GroupRouters []groupRouter
// Get a new GroupRouters // Get a new GroupRouters
func NewGroupRouters() GroupRouters { func NewGroupRouters() GroupRouters {
return make([]groupRouter, 0) return make(GroupRouters, 0)
} }
// Add Router in the GroupRouters // Add Router in the GroupRouters
// it is for plugin or module to register router // it is for plugin or module to register router
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) { func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter var newRG groupRouter
if len(mappingMethod) > 0 { if len(mappingMethod) > 0 {
newRG = groupRouter{ newRG = groupRouter{
@ -48,16 +48,16 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM
"", "",
} }
} }
gr = append(gr, newRG) *gr = append(*gr, newRG)
} }
func (gr GroupRouters) AddAuto(c ControllerInterface) { func (gr *GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{ newRG := groupRouter{
"", "",
c, c,
"", "",
} }
gr = append(gr, newRG) *gr = append(*gr, newRG)
} }
// AddGroupRouter with the prefix // AddGroupRouter with the prefix

2
cache/file.go vendored
View File

@ -147,6 +147,8 @@ func (this *FileCache) Get(key string) interface{} {
// timeout means how long to keep this file, unit of ms. // timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. // if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
func (this *FileCache) Put(key string, val interface{}, timeout int64) error { func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
gob.Register(val)
filename := this.getCacheFileName(key) filename := this.getCacheFileName(key)
var item FileCacheItem var item FileCacheItem
item.Data = val item.Data = val

View File

@ -1,7 +1,14 @@
package context package context
import ( import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/http" "net/http"
"strconv"
"strings"
"time"
"github.com/astaxie/beego/middleware" "github.com/astaxie/beego/middleware"
) )
@ -59,3 +66,41 @@ func (ctx *Context) GetCookie(key string) string {
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...) ctx.Output.Cookie(name, value, others...)
} }
// Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
return "", false
}
parts := strings.SplitN(val, "|", 3)
if len(parts) != 3 {
return "", false
}
vs := parts[0]
timestamp := parts[1]
sig := parts[2]
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
return "", false
}
res, _ := base64.URLEncoding.DecodeString(vs)
return string(res), true
}
// Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
sig := fmt.Sprintf("%02x", h.Sum(nil))
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
ctx.Output.Cookie(name, cookie, others...)
}

View File

@ -77,39 +77,77 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
if len(others) > 0 { if len(others) > 0 {
switch others[0].(type) { switch v := others[0].(type) {
case int: case int:
if others[0].(int) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
case int64: case int64:
if others[0].(int64) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int64) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
case int32: case int32:
if others[0].(int32) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int32) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
} }
} }
// the settings below
// Path, Domain, Secure, HttpOnly
// can use nil skip set
// default "/"
if len(others) > 1 { if len(others) > 1 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string))) if v, ok := others[1].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
} }
} else {
fmt.Fprintf(&b, "; Path=%s", "/")
}
// default empty
if len(others) > 2 { if len(others) > 2 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string))) if v, ok := others[2].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
} }
}
// default empty
if len(others) > 3 { if len(others) > 3 {
var secure bool
switch v := others[3].(type) {
case bool:
secure = v
default:
if others[3] != nil {
secure = true
}
}
if secure {
fmt.Fprintf(&b, "; Secure") fmt.Fprintf(&b, "; Secure")
} }
}
// default true
httponly := true
if len(others) > 4 { if len(others) > 4 {
if v, ok := others[4].(bool); ok && !v || others[4] == nil {
// HttpOnly = false
httponly = false
}
}
if httponly {
fmt.Fprintf(&b, "; HttpOnly") fmt.Fprintf(&b, "; HttpOnly")
} }
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
} }

View File

@ -2,11 +2,7 @@ package beego
import ( import (
"bytes" "bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors" "errors"
"fmt"
"html/template" "html/template"
"io" "io"
"io/ioutil" "io/ioutil"
@ -17,7 +13,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
@ -313,11 +308,11 @@ func (c *Controller) GetString(key string) string {
// GetStrings returns the input string slice by key string. // GetStrings returns the input string slice by key string.
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. // it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
func (c *Controller) GetStrings(key string) []string { func (c *Controller) GetStrings(key string) []string {
r := c.Ctx.Request f := c.Input()
if r.Form == nil { if f == nil {
return []string{} return []string{}
} }
vs := r.Form[key] vs := f[key]
if len(vs) > 0 { if len(vs) > 0 {
return vs return vs
} }
@ -417,40 +412,12 @@ func (c *Controller) IsAjax() bool {
// GetSecureCookie returns decoded cookie value from encoded browser cookie values. // GetSecureCookie returns decoded cookie value from encoded browser cookie values.
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
val := c.Ctx.GetCookie(key) return c.Ctx.GetSecureCookie(Secret, key)
if val == "" {
return "", false
}
parts := strings.SplitN(val, "|", 3)
if len(parts) != 3 {
return "", false
}
vs := parts[0]
timestamp := parts[1]
sig := parts[2]
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
return "", false
}
res, _ := base64.URLEncoding.DecodeString(vs)
return string(res), true
} }
// SetSecureCookie puts value into cookie after encoded the value. // SetSecureCookie puts value into cookie after encoded the value.
func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) { func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(val)) c.Ctx.SetSecureCookie(Secret, name, value, others...)
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
sig := fmt.Sprintf("%02x", h.Sum(nil))
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
c.Ctx.SetCookie(name, cookie, age, "/")
} }
// XsrfToken creates a xsrf token string and returns. // XsrfToken creates a xsrf token string and returns.

View File

@ -24,7 +24,7 @@ func Get(url string) *BeegoHttpRequest {
req.Method = "GET" req.Method = "GET"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
} }
// Post returns *BeegoHttpRequest with POST method. // Post returns *BeegoHttpRequest with POST method.
@ -33,7 +33,7 @@ func Post(url string) *BeegoHttpRequest {
req.Method = "POST" req.Method = "POST"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
} }
// Put returns *BeegoHttpRequest with PUT method. // Put returns *BeegoHttpRequest with PUT method.
@ -42,7 +42,7 @@ func Put(url string) *BeegoHttpRequest {
req.Method = "PUT" req.Method = "PUT"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
} }
// Delete returns *BeegoHttpRequest DELETE GET method. // Delete returns *BeegoHttpRequest DELETE GET method.
@ -51,7 +51,7 @@ func Delete(url string) *BeegoHttpRequest {
req.Method = "DELETE" req.Method = "DELETE"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
} }
// Head returns *BeegoHttpRequest with HEAD method. // Head returns *BeegoHttpRequest with HEAD method.
@ -60,7 +60,7 @@ func Head(url string) *BeegoHttpRequest {
req.Method = "HEAD" req.Method = "HEAD"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil} return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
} }
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request. // BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
@ -72,6 +72,8 @@ type BeegoHttpRequest struct {
connectTimeout time.Duration connectTimeout time.Duration
readWriteTimeout time.Duration readWriteTimeout time.Duration
tlsClientConfig *tls.Config tlsClientConfig *tls.Config
proxy func(*http.Request) (*url.URL, error)
transport http.RoundTripper
} }
// Debug sets show debug or not when executing request. // Debug sets show debug or not when executing request.
@ -105,6 +107,24 @@ func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
return b return b
} }
// Set transport to
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
b.transport = transport
return b
}
// Set http proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
b.proxy = proxy
return b
}
// Param adds query param in to request. // Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2... // params build query string as ?key1=value1&key2=value2...
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest { func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
@ -171,12 +191,34 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
println(string(dump)) println(string(dump))
} }
client := &http.Client{ trans := b.transport
Transport: &http.Transport{
if trans == nil {
// create default transport
trans = &http.Transport{
TLSClientConfig: b.tlsClientConfig, TLSClientConfig: b.tlsClientConfig,
Proxy: b.proxy,
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout), Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
},
} }
} else {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.tlsClientConfig
}
if t.Proxy == nil {
t.Proxy = b.proxy
}
if t.Dial == nil {
t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout)
}
}
}
client := &http.Client{
Transport: trans,
}
resp, err := client.Do(b.req) resp, err := client.Do(b.req)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -35,7 +35,7 @@ var (
"istartswith": true, "istartswith": true,
"iendswith": true, "iendswith": true,
"in": true, "in": true,
// "range": true, "between": true,
// "year": true, // "year": true,
// "month": true, // "month": true,
// "day": true, // "day": true,
@ -103,16 +103,37 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
switch fi.fieldType { switch fi.fieldType {
case TypeBooleanField: case TypeBooleanField:
if nb, ok := field.Interface().(sql.NullBool); ok {
value = nil
if nb.Valid {
value = nb.Bool
}
} else {
value = field.Bool() value = field.Bool()
}
case TypeCharField, TypeTextField: case TypeCharField, TypeTextField:
if ns, ok := field.Interface().(sql.NullString); ok {
value = nil
if ns.Valid {
value = ns.String
}
} else {
value = field.String() value = field.String()
}
case TypeFloatField, TypeDecimalField: case TypeFloatField, TypeDecimalField:
if nf, ok := field.Interface().(sql.NullFloat64); ok {
value = nil
if nf.Valid {
value = nf.Float64
}
} else {
vu := field.Interface() vu := field.Interface()
if _, ok := vu.(float32); ok { if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64() value, _ = StrTo(ToStr(vu)).Float64()
} else { } else {
value = field.Float() value = field.Float()
} }
}
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
value = field.Interface() value = field.Interface()
if t, ok := value.(time.Time); ok { if t, ok := value.(time.Time); ok {
@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint() value = field.Uint()
case fi.fieldType&IsIntegerField > 0: case fi.fieldType&IsIntegerField > 0:
if ni, ok := field.Interface().(sql.NullInt64); ok {
value = nil
if ni.Valid {
value = ni.Int64
}
} else {
value = field.Int() value = field.Int()
}
case fi.fieldType&IsRelField > 0: case fi.fieldType&IsRelField > 0:
if field.IsNil() { if field.IsNil() {
value = nil value = nil
@ -144,6 +172,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert { if fi.auto_now || fi.auto_now_add && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
}
}
tnow := time.Now() tnow := time.Now()
d.ins.TimeToDB(&tnow, tz) d.ins.TimeToDB(&tnow, tz)
value = tnow value = tnow
@ -883,13 +916,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
} }
arg := params[0] arg := params[0]
if operator == "in" { switch operator {
case "in":
marks := make([]string, len(params)) marks := make([]string, len(params))
for i, _ := range marks { for i, _ := range marks {
marks[i] = "?" marks[i] = "?"
} }
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
} else { case "between":
if len(params) != 2 {
panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params)))
}
sql = "BETWEEN ? AND ?"
default:
if len(params) > 1 { if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
} }
@ -1117,18 +1156,38 @@ setValue:
switch { switch {
case fieldType == TypeBooleanField: case fieldType == TypeBooleanField:
if isNative { if isNative {
if nb, ok := field.Interface().(sql.NullBool); ok {
if value == nil {
nb.Valid = false
} else {
nb.Bool = value.(bool)
nb.Valid = true
}
field.Set(reflect.ValueOf(nb))
} else {
if value == nil { if value == nil {
value = false value = false
} }
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
}
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil {
ns.Valid = false
} else {
ns.String = value.(string)
ns.Valid = true
}
field.Set(reflect.ValueOf(ns))
} else {
if value == nil { if value == nil {
value = "" value = ""
} }
field.SetString(value.(string)) field.SetString(value.(string))
} }
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative { if isNative {
if value == nil { if value == nil {
@ -1146,19 +1205,40 @@ setValue:
} }
} else { } else {
if isNative { if isNative {
if ni, ok := field.Interface().(sql.NullInt64); ok {
if value == nil {
ni.Valid = false
} else {
ni.Int64 = value.(int64)
ni.Valid = true
}
field.Set(reflect.ValueOf(ni))
} else {
if value == nil { if value == nil {
value = int64(0) value = int64(0)
} }
field.SetInt(value.(int64)) field.SetInt(value.(int64))
} }
} }
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField: case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if isNative { if isNative {
if nf, ok := field.Interface().(sql.NullFloat64); ok {
if value == nil {
nf.Valid = false
} else {
nf.Float64 = value.(float64)
nf.Valid = true
}
field.Set(reflect.ValueOf(nf))
} else {
if value == nil { if value == nil {
value = float64(0) value = float64(0)
} }
field.SetFloat(value.(float64)) field.SetFloat(value.(float64))
} }
}
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
if value != nil { if value != nil {
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType

View File

@ -168,7 +168,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
} }
if dataBaseCache.add(aliasName, al) == false { if dataBaseCache.add(aliasName, al) == false {
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName) return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
} }
return al, nil return al, nil
@ -239,7 +239,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
} else { } else {
return fmt.Errorf("DataBase name `%s` not registered\n", aliasName) return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
} }
return nil return nil
} }
@ -260,3 +260,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
} }
} }
// Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
if len(aliasNames) > 0 {
name = aliasNames[0]
} else {
name = "default"
}
if al, ok := dataBaseCache.get(name); ok {
return al.DB, nil
} else {
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
}

View File

@ -98,3 +98,9 @@ func (mc *_modelCache) clean() {
mc.cacheByFN = make(map[string]*modelInfo) mc.cacheByFN = make(map[string]*modelInfo)
mc.done = false mc.done = false
} }
// Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()
}

View File

@ -1,6 +1,7 @@
package orm package orm
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -137,6 +138,10 @@ type DataNull struct {
Float32 float32 `orm:"null"` Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"` Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"` Decimal float64 `orm:"digits(8);decimals(4);null"`
NullString sql.NullString `orm:"null"`
NullBool sql.NullBool `orm:"null"`
NullFloat64 sql.NullFloat64 `orm:"null"`
NullInt64 sql.NullInt64 `orm:"null"`
} }
// only for mysql // only for mysql
@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm
#### Sqlite3 #### Sqlite3
touch /path/to/orm_test.db
export ORM_DRIVER=sqlite3 export ORM_DRIVER=sqlite3
export ORM_SOURCE=/path/to/orm_test.db export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm go test -v github.com/astaxie/beego/orm

View File

@ -1,6 +1,7 @@
package orm package orm
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value // return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Interface().(type) {
case reflect.Int8: case int8:
ft = TypeBitField ft = TypeBitField
case reflect.Int16: case int16:
ft = TypeSmallIntegerField ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int: case int32, int:
ft = TypeIntegerField ft = TypeIntegerField
case reflect.Int64: case int64, sql.NullInt64:
ft = TypeBigIntegerField ft = TypeBigIntegerField
case reflect.Uint8: case uint8:
ft = TypePositiveBitField ft = TypePositiveBitField
case reflect.Uint16: case uint16:
ft = TypePositiveSmallIntegerField ft = TypePositiveSmallIntegerField
case reflect.Uint32, reflect.Uint: case uint32, uint:
ft = TypePositiveIntegerField ft = TypePositiveIntegerField
case reflect.Uint64: case uint64:
ft = TypePositiveBigIntegerField ft = TypePositiveBigIntegerField
case reflect.Float32, reflect.Float64: case float32, float64, sql.NullFloat64:
ft = TypeFloatField ft = TypeFloatField
case reflect.Bool: case bool, sql.NullBool:
ft = TypeBooleanField ft = TypeBooleanField
case reflect.String: case string, sql.NullString:
ft = TypeCharField ft = TypeCharField
case reflect.Invalid:
default: default:
if elm.CanInterface() { if elm.CanInterface() {
if _, ok := elm.Interface().(time.Time); ok { if _, ok := elm.Interface().(time.Time); ok {

View File

@ -2,6 +2,7 @@ package orm
import ( import (
"bytes" "bytes"
"database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -138,6 +139,15 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
} }
} }
func TestGetDB(t *testing.T) {
if db, err := GetDB(); err != nil {
throwFailNow(t, err)
} else {
err = db.Ping()
throwFailNow(t, err)
}
}
func TestSyncDb(t *testing.T) { func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User)) RegisterModel(new(User))
@ -258,12 +268,45 @@ func TestNullDataTypes(t *testing.T) {
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(d.NullBool.Valid, false))
throwFail(t, AssertIs(d.NullString.Valid, false))
throwFail(t, AssertIs(d.NullInt64.Valid, false))
throwFail(t, AssertIs(d.NullFloat64.Valid, false))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
d = DataNull{Id: 2} d = DataNull{Id: 2}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
d = DataNull{
DateTime: time.Now(),
NullString: sql.NullString{"test", true},
NullBool: sql.NullBool{true, true},
NullInt64: sql.NullInt64{42, true},
NullFloat64: sql.NullFloat64{42.42, true},
}
id, err = dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3}
err = dORM.Read(&d)
throwFail(t, err)
throwFail(t, AssertIs(d.NullBool.Valid, true))
throwFail(t, AssertIs(d.NullBool.Bool, true))
throwFail(t, AssertIs(d.NullString.Valid, true))
throwFail(t, AssertIs(d.NullString.String, "test"))
throwFail(t, AssertIs(d.NullInt64.Valid, true))
throwFail(t, AssertIs(d.NullInt64.Int64, 42))
throwFail(t, AssertIs(d.NullFloat64.Valid, true))
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
} }
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
@ -619,6 +662,14 @@ func TestOperators(t *testing.T) {
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("id__between", 2, 3).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("id__between", []int{2, 3}).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
} }
func TestSetCond(t *testing.T) { func TestSetCond(t *testing.T) {
@ -1577,7 +1628,6 @@ func TestDelete(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 4)) throwFail(t, AssertIs(num, 4))
fmt.Println("...")
qs = dORM.QueryTable("comment") qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete() num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err) throwFail(t, err)

View File

@ -8,7 +8,7 @@ package auth
// } // }
// return false // return false
// } // }
// authPlugin := auth.NewBasicAuthenticator(SecretAuth) // authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm")
// beego.AddFilter("*","AfterStatic",authPlugin) // beego.AddFilter("*","AfterStatic",authPlugin)
import ( import (

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
@ -545,14 +546,26 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
//static file server //static file server
for prefix, staticDir := range StaticDir { for prefix, staticDir := range StaticDir {
if len(prefix) == 0 {
continue
}
if r.URL.Path == "/favicon.ico" { if r.URL.Path == "/favicon.ico" {
file := staticDir + r.URL.Path file := path.Join(staticDir, r.URL.Path)
if utils.FileExists(file) {
http.ServeFile(w, r, file) http.ServeFile(w, r, file)
w.started = true w.started = true
goto Admin goto Admin
} }
}
if strings.HasPrefix(r.URL.Path, prefix) { if strings.HasPrefix(r.URL.Path, prefix) {
file := staticDir + r.URL.Path[len(prefix):] if len(r.URL.Path) > len(prefix) && r.URL.Path[len(prefix)] != '/' {
continue
}
if r.URL.Path == prefix && prefix[len(prefix)-1] != '/' {
http.Redirect(rw, r, r.URL.Path+"/", 302)
goto Admin
}
file := path.Join(staticDir, r.URL.Path[len(prefix):])
finfo, err := os.Stat(file) finfo, err := os.Stat(file)
if err != nil { if err != nil {
if RunMode == "dev" { if RunMode == "dev" {

203
session/sess_couchbase.go Normal file
View File

@ -0,0 +1,203 @@
package session
import (
"github.com/couchbaselabs/go-couchbase"
"net/http"
"strings"
"sync"
)
var couchbpder = &CouchbaseProvider{}
type CouchbaseSessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
type CouchbaseProvider struct {
maxlifetime int64
savePath string
pool string
bucket string
b *couchbase.Bucket
}
func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
} else {
return nil
}
return nil
}
func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
func (cs *CouchbaseSessionStore) Flush() error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
func (cs *CouchbaseSessionStore) SessionID() string {
return cs.sid
}
func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
defer cs.b.Close()
// if rs.values is empty, return directly
if len(cs.values) < 1 {
cs.b.Delete(cs.sid)
return
}
bo, err := encodeGob(cs.values)
if err != nil {
return
}
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
if err != nil {
return nil
}
return bucket
}
// init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
}
return nil
}
// read couchbase session by sid
func (cp *CouchbaseProvider) SessionRead(sid string) (SessionStore, error) {
cp.b = cp.getBucket()
var doc []byte
err := cp.b.Get(sid, &doc)
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
func (cp *CouchbaseProvider) SessionExist(sid string) bool {
cp.b = cp.getBucket()
defer cp.b.Close()
var doc []byte
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false
} else {
return true
}
}
func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
cp.b = cp.getBucket()
var doc []byte
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
cp.b.Set(sid, int(cp.maxlifetime), "")
} else {
err := cp.b.Delete(oldsid)
if err != nil {
return nil, err
}
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
}
err := cp.b.Get(sid, &doc)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
cp.b.Delete(sid)
return nil
}
func (cp *CouchbaseProvider) SessionGC() {
return
}
func (cp *CouchbaseProvider) SessionAll() int {
return 0
}
func init() {
Register("couchbase", couchbpder)
}

View File

@ -152,8 +152,7 @@ func (fp *FileProvider) SessionExist(sid string) bool {
func (fp *FileProvider) SessionDestroy(sid string) error { func (fp *FileProvider) SessionDestroy(sid string) error {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
os.Remove(path.Join(fp.savePath))
return nil return nil
} }

View File

@ -129,7 +129,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
} }
return c, err return c, err
}, rp.poolsize) }, rp.poolsize)
return nil
return rp.poollist.Get().Err()
} }
// read redis session by sid // read redis session by sid

View File

@ -56,11 +56,10 @@ type managerConfig struct {
EnableSetCookie bool `json:"enableSetCookie,omitempty"` EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"` Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"` Maxlifetime int64 `json:"maxLifetime"`
Maxage int `json:"maxage"`
Secure bool `json:"secure"` Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"` SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"` SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int64 `json:"cookieLifeTime"` CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"` ProviderConfig string `json:"providerConfig"`
} }
@ -125,8 +124,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.config.Secure} Secure: manager.config.Secure}
if manager.config.Maxage >= 0 { if manager.config.CookieLifeTime >= 0 {
cookie.MaxAge = manager.config.Maxage cookie.MaxAge = manager.config.CookieLifeTime
} }
if manager.config.EnableSetCookie { if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
@ -144,8 +143,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.config.Secure} Secure: manager.config.Secure}
if manager.config.Maxage >= 0 { if manager.config.CookieLifeTime >= 0 {
cookie.MaxAge = manager.config.Maxage cookie.MaxAge = manager.config.CookieLifeTime
} }
if manager.config.EnableSetCookie { if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
@ -206,8 +205,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
} }
if manager.config.Maxage >= 0 { if manager.config.CookieLifeTime >= 0 {
cookie.MaxAge = manager.config.Maxage cookie.MaxAge = manager.config.CookieLifeTime
} }
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)

View File

@ -67,7 +67,7 @@ const (
fieldIdName = "captcha_id" fieldIdName = "captcha_id"
fieldCaptchaName = "captcha" fieldCaptchaName = "captcha"
cachePrefix = "captcha_" cachePrefix = "captcha_"
urlPrefix = "/captcha/" defaultURLPrefix = "/captcha/"
) )
// Captcha struct // Captcha struct
@ -76,7 +76,7 @@ type Captcha struct {
store cache.Cache store cache.Cache
// url prefix for captcha image // url prefix for captcha image
urlPrefix string URLPrefix string
// specify captcha id input field name // specify captcha id input field name
FieldIdName string FieldIdName string
@ -155,7 +155,7 @@ func (c *Captcha) CreateCaptchaHtml() template.HTML {
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+ return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
`<a class="captcha" href="javascript:">`+ `<a class="captcha" href="javascript:">`+
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+ `<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value)) `</a>`, c.FieldIdName, value, c.URLPrefix, value, c.URLPrefix, value))
} }
// create a new captcha id // create a new captcha id
@ -224,14 +224,14 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
cpt.StdHeight = stdHeight cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 { if len(urlPrefix) == 0 {
urlPrefix = urlPrefix urlPrefix = defaultURLPrefix
} }
if urlPrefix[len(urlPrefix)-1] != '/' { if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/" urlPrefix += "/"
} }
cpt.urlPrefix = urlPrefix cpt.URLPrefix = urlPrefix
return cpt return cpt
} }
@ -242,7 +242,7 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
cpt := NewCaptcha(urlPrefix, store) cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image // create filter for serve captcha image
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler) beego.AddFilter(cpt.URLPrefix+":", "BeforeRouter", cpt.Handler)
// add to template func map // add to template func map
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml) beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)

View File

@ -443,7 +443,7 @@ func (b Base64) GetLimitValue() interface{} {
} }
// just for chinese mobile phone number // just for chinese mobile phone number
var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][01236789]))\\d{8}$") var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][012356789]))\\d{8}$")
type Mobile struct { type Mobile struct {
Match Match