mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 16:00:59 +00:00
Merge branch 'release/1.1.1'
This commit is contained in:
commit
439b1afb85
12
beego.go
12
beego.go
@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// beego web framework version.
|
// beego web framework version.
|
||||||
const VERSION = "1.1.0"
|
const VERSION = "1.1.1"
|
||||||
|
|
||||||
type hookfunc func() error //hook function to run
|
type hookfunc func() error //hook function to run
|
||||||
var hooks []hookfunc //hook function slice to store the hookfunc
|
var hooks []hookfunc //hook function slice to store the hookfunc
|
||||||
@ -28,12 +28,12 @@ type GroupRouters []groupRouter
|
|||||||
|
|
||||||
// Get a new GroupRouters
|
// Get a new GroupRouters
|
||||||
func NewGroupRouters() GroupRouters {
|
func NewGroupRouters() GroupRouters {
|
||||||
return make([]groupRouter, 0)
|
return make(GroupRouters, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Router in the GroupRouters
|
// Add Router in the GroupRouters
|
||||||
// it is for plugin or module to register router
|
// it is for plugin or module to register router
|
||||||
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
|
func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
|
||||||
var newRG groupRouter
|
var newRG groupRouter
|
||||||
if len(mappingMethod) > 0 {
|
if len(mappingMethod) > 0 {
|
||||||
newRG = groupRouter{
|
newRG = groupRouter{
|
||||||
@ -48,16 +48,16 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM
|
|||||||
"",
|
"",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
gr = append(gr, newRG)
|
*gr = append(*gr, newRG)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gr GroupRouters) AddAuto(c ControllerInterface) {
|
func (gr *GroupRouters) AddAuto(c ControllerInterface) {
|
||||||
newRG := groupRouter{
|
newRG := groupRouter{
|
||||||
"",
|
"",
|
||||||
c,
|
c,
|
||||||
"",
|
"",
|
||||||
}
|
}
|
||||||
gr = append(gr, newRG)
|
*gr = append(*gr, newRG)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddGroupRouter with the prefix
|
// AddGroupRouter with the prefix
|
||||||
|
2
cache/file.go
vendored
2
cache/file.go
vendored
@ -147,6 +147,8 @@ func (this *FileCache) Get(key string) interface{} {
|
|||||||
// timeout means how long to keep this file, unit of ms.
|
// timeout means how long to keep this file, unit of ms.
|
||||||
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
|
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
|
||||||
func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
|
func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
|
||||||
|
gob.Register(val)
|
||||||
|
|
||||||
filename := this.getCacheFileName(key)
|
filename := this.getCacheFileName(key)
|
||||||
var item FileCacheItem
|
var item FileCacheItem
|
||||||
item.Data = val
|
item.Data = val
|
||||||
|
@ -1,7 +1,14 @@
|
|||||||
package context
|
package context
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/astaxie/beego/middleware"
|
"github.com/astaxie/beego/middleware"
|
||||||
)
|
)
|
||||||
@ -59,3 +66,41 @@ func (ctx *Context) GetCookie(key string) string {
|
|||||||
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
|
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
|
||||||
ctx.Output.Cookie(name, value, others...)
|
ctx.Output.Cookie(name, value, others...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get secure cookie from request by a given key.
|
||||||
|
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
|
||||||
|
val := ctx.Input.Cookie(key)
|
||||||
|
if val == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(val, "|", 3)
|
||||||
|
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
vs := parts[0]
|
||||||
|
timestamp := parts[1]
|
||||||
|
sig := parts[2]
|
||||||
|
|
||||||
|
h := hmac.New(sha1.New, []byte(Secret))
|
||||||
|
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||||
|
|
||||||
|
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
res, _ := base64.URLEncoding.DecodeString(vs)
|
||||||
|
return string(res), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Secure cookie for response.
|
||||||
|
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
|
||||||
|
vs := base64.URLEncoding.EncodeToString([]byte(value))
|
||||||
|
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
h := hmac.New(sha1.New, []byte(Secret))
|
||||||
|
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||||
|
sig := fmt.Sprintf("%02x", h.Sum(nil))
|
||||||
|
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
|
||||||
|
ctx.Output.Cookie(name, cookie, others...)
|
||||||
|
}
|
||||||
|
@ -77,39 +77,77 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
|
|||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
|
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
|
||||||
if len(others) > 0 {
|
if len(others) > 0 {
|
||||||
switch others[0].(type) {
|
switch v := others[0].(type) {
|
||||||
case int:
|
case int:
|
||||||
if others[0].(int) > 0 {
|
if v > 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int))
|
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||||
} else if others[0].(int) < 0 {
|
} else if v < 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=0")
|
fmt.Fprintf(&b, "; Max-Age=0")
|
||||||
}
|
}
|
||||||
case int64:
|
case int64:
|
||||||
if others[0].(int64) > 0 {
|
if v > 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64))
|
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||||
} else if others[0].(int64) < 0 {
|
} else if v < 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=0")
|
fmt.Fprintf(&b, "; Max-Age=0")
|
||||||
}
|
}
|
||||||
case int32:
|
case int32:
|
||||||
if others[0].(int32) > 0 {
|
if v > 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32))
|
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||||
} else if others[0].(int32) < 0 {
|
} else if v < 0 {
|
||||||
fmt.Fprintf(&b, "; Max-Age=0")
|
fmt.Fprintf(&b, "; Max-Age=0")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the settings below
|
||||||
|
// Path, Domain, Secure, HttpOnly
|
||||||
|
// can use nil skip set
|
||||||
|
|
||||||
|
// default "/"
|
||||||
if len(others) > 1 {
|
if len(others) > 1 {
|
||||||
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string)))
|
if v, ok := others[1].(string); ok && len(v) > 0 {
|
||||||
|
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(&b, "; Path=%s", "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// default empty
|
||||||
if len(others) > 2 {
|
if len(others) > 2 {
|
||||||
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string)))
|
if v, ok := others[2].(string); ok && len(v) > 0 {
|
||||||
|
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// default empty
|
||||||
if len(others) > 3 {
|
if len(others) > 3 {
|
||||||
fmt.Fprintf(&b, "; Secure")
|
var secure bool
|
||||||
|
switch v := others[3].(type) {
|
||||||
|
case bool:
|
||||||
|
secure = v
|
||||||
|
default:
|
||||||
|
if others[3] != nil {
|
||||||
|
secure = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if secure {
|
||||||
|
fmt.Fprintf(&b, "; Secure")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// default true
|
||||||
|
httponly := true
|
||||||
if len(others) > 4 {
|
if len(others) > 4 {
|
||||||
|
if v, ok := others[4].(bool); ok && !v || others[4] == nil {
|
||||||
|
// HttpOnly = false
|
||||||
|
httponly = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if httponly {
|
||||||
fmt.Fprintf(&b, "; HttpOnly")
|
fmt.Fprintf(&b, "; HttpOnly")
|
||||||
}
|
}
|
||||||
|
|
||||||
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
|
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,7 @@ package beego
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -17,7 +13,6 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/astaxie/beego/context"
|
"github.com/astaxie/beego/context"
|
||||||
"github.com/astaxie/beego/session"
|
"github.com/astaxie/beego/session"
|
||||||
@ -313,11 +308,11 @@ func (c *Controller) GetString(key string) string {
|
|||||||
// GetStrings returns the input string slice by key string.
|
// GetStrings returns the input string slice by key string.
|
||||||
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
|
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
|
||||||
func (c *Controller) GetStrings(key string) []string {
|
func (c *Controller) GetStrings(key string) []string {
|
||||||
r := c.Ctx.Request
|
f := c.Input()
|
||||||
if r.Form == nil {
|
if f == nil {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
vs := r.Form[key]
|
vs := f[key]
|
||||||
if len(vs) > 0 {
|
if len(vs) > 0 {
|
||||||
return vs
|
return vs
|
||||||
}
|
}
|
||||||
@ -417,40 +412,12 @@ func (c *Controller) IsAjax() bool {
|
|||||||
|
|
||||||
// GetSecureCookie returns decoded cookie value from encoded browser cookie values.
|
// GetSecureCookie returns decoded cookie value from encoded browser cookie values.
|
||||||
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
|
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
|
||||||
val := c.Ctx.GetCookie(key)
|
return c.Ctx.GetSecureCookie(Secret, key)
|
||||||
if val == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(val, "|", 3)
|
|
||||||
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
vs := parts[0]
|
|
||||||
timestamp := parts[1]
|
|
||||||
sig := parts[2]
|
|
||||||
|
|
||||||
h := hmac.New(sha1.New, []byte(Secret))
|
|
||||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
|
||||||
|
|
||||||
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
res, _ := base64.URLEncoding.DecodeString(vs)
|
|
||||||
return string(res), true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSecureCookie puts value into cookie after encoded the value.
|
// SetSecureCookie puts value into cookie after encoded the value.
|
||||||
func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) {
|
func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
|
||||||
vs := base64.URLEncoding.EncodeToString([]byte(val))
|
c.Ctx.SetSecureCookie(Secret, name, value, others...)
|
||||||
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
||||||
h := hmac.New(sha1.New, []byte(Secret))
|
|
||||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
|
||||||
sig := fmt.Sprintf("%02x", h.Sum(nil))
|
|
||||||
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
|
|
||||||
c.Ctx.SetCookie(name, cookie, age, "/")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// XsrfToken creates a xsrf token string and returns.
|
// XsrfToken creates a xsrf token string and returns.
|
||||||
|
@ -24,7 +24,7 @@ func Get(url string) *BeegoHttpRequest {
|
|||||||
req.Method = "GET"
|
req.Method = "GET"
|
||||||
req.Header = http.Header{}
|
req.Header = http.Header{}
|
||||||
req.Header.Set("User-Agent", defaultUserAgent)
|
req.Header.Set("User-Agent", defaultUserAgent)
|
||||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post returns *BeegoHttpRequest with POST method.
|
// Post returns *BeegoHttpRequest with POST method.
|
||||||
@ -33,7 +33,7 @@ func Post(url string) *BeegoHttpRequest {
|
|||||||
req.Method = "POST"
|
req.Method = "POST"
|
||||||
req.Header = http.Header{}
|
req.Header = http.Header{}
|
||||||
req.Header.Set("User-Agent", defaultUserAgent)
|
req.Header.Set("User-Agent", defaultUserAgent)
|
||||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put returns *BeegoHttpRequest with PUT method.
|
// Put returns *BeegoHttpRequest with PUT method.
|
||||||
@ -42,7 +42,7 @@ func Put(url string) *BeegoHttpRequest {
|
|||||||
req.Method = "PUT"
|
req.Method = "PUT"
|
||||||
req.Header = http.Header{}
|
req.Header = http.Header{}
|
||||||
req.Header.Set("User-Agent", defaultUserAgent)
|
req.Header.Set("User-Agent", defaultUserAgent)
|
||||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete returns *BeegoHttpRequest DELETE GET method.
|
// Delete returns *BeegoHttpRequest DELETE GET method.
|
||||||
@ -51,7 +51,7 @@ func Delete(url string) *BeegoHttpRequest {
|
|||||||
req.Method = "DELETE"
|
req.Method = "DELETE"
|
||||||
req.Header = http.Header{}
|
req.Header = http.Header{}
|
||||||
req.Header.Set("User-Agent", defaultUserAgent)
|
req.Header.Set("User-Agent", defaultUserAgent)
|
||||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Head returns *BeegoHttpRequest with HEAD method.
|
// Head returns *BeegoHttpRequest with HEAD method.
|
||||||
@ -60,7 +60,7 @@ func Head(url string) *BeegoHttpRequest {
|
|||||||
req.Method = "HEAD"
|
req.Method = "HEAD"
|
||||||
req.Header = http.Header{}
|
req.Header = http.Header{}
|
||||||
req.Header.Set("User-Agent", defaultUserAgent)
|
req.Header.Set("User-Agent", defaultUserAgent)
|
||||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
|
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
|
||||||
@ -72,6 +72,8 @@ type BeegoHttpRequest struct {
|
|||||||
connectTimeout time.Duration
|
connectTimeout time.Duration
|
||||||
readWriteTimeout time.Duration
|
readWriteTimeout time.Duration
|
||||||
tlsClientConfig *tls.Config
|
tlsClientConfig *tls.Config
|
||||||
|
proxy func(*http.Request) (*url.URL, error)
|
||||||
|
transport http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug sets show debug or not when executing request.
|
// Debug sets show debug or not when executing request.
|
||||||
@ -105,6 +107,24 @@ func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set transport to
|
||||||
|
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
|
||||||
|
b.transport = transport
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set http proxy
|
||||||
|
// example:
|
||||||
|
//
|
||||||
|
// func(req *http.Request) (*url.URL, error) {
|
||||||
|
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
|
||||||
|
// return u, nil
|
||||||
|
// }
|
||||||
|
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
|
||||||
|
b.proxy = proxy
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Param adds query param in to request.
|
// Param adds query param in to request.
|
||||||
// params build query string as ?key1=value1&key2=value2...
|
// params build query string as ?key1=value1&key2=value2...
|
||||||
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
|
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
|
||||||
@ -171,12 +191,34 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
|
|||||||
println(string(dump))
|
println(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
trans := b.transport
|
||||||
Transport: &http.Transport{
|
|
||||||
|
if trans == nil {
|
||||||
|
// create default transport
|
||||||
|
trans = &http.Transport{
|
||||||
TLSClientConfig: b.tlsClientConfig,
|
TLSClientConfig: b.tlsClientConfig,
|
||||||
|
Proxy: b.proxy,
|
||||||
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
|
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
|
||||||
},
|
}
|
||||||
|
} else {
|
||||||
|
// if b.transport is *http.Transport then set the settings.
|
||||||
|
if t, ok := trans.(*http.Transport); ok {
|
||||||
|
if t.TLSClientConfig == nil {
|
||||||
|
t.TLSClientConfig = b.tlsClientConfig
|
||||||
|
}
|
||||||
|
if t.Proxy == nil {
|
||||||
|
t.Proxy = b.proxy
|
||||||
|
}
|
||||||
|
if t.Dial == nil {
|
||||||
|
t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: trans,
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.Do(b.req)
|
resp, err := client.Do(b.req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
128
orm/db.go
128
orm/db.go
@ -35,7 +35,7 @@ var (
|
|||||||
"istartswith": true,
|
"istartswith": true,
|
||||||
"iendswith": true,
|
"iendswith": true,
|
||||||
"in": true,
|
"in": true,
|
||||||
// "range": true,
|
"between": true,
|
||||||
// "year": true,
|
// "year": true,
|
||||||
// "month": true,
|
// "month": true,
|
||||||
// "day": true,
|
// "day": true,
|
||||||
@ -103,15 +103,36 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
} else {
|
} else {
|
||||||
switch fi.fieldType {
|
switch fi.fieldType {
|
||||||
case TypeBooleanField:
|
case TypeBooleanField:
|
||||||
value = field.Bool()
|
if nb, ok := field.Interface().(sql.NullBool); ok {
|
||||||
case TypeCharField, TypeTextField:
|
value = nil
|
||||||
value = field.String()
|
if nb.Valid {
|
||||||
case TypeFloatField, TypeDecimalField:
|
value = nb.Bool
|
||||||
vu := field.Interface()
|
}
|
||||||
if _, ok := vu.(float32); ok {
|
|
||||||
value, _ = StrTo(ToStr(vu)).Float64()
|
|
||||||
} else {
|
} else {
|
||||||
value = field.Float()
|
value = field.Bool()
|
||||||
|
}
|
||||||
|
case TypeCharField, TypeTextField:
|
||||||
|
if ns, ok := field.Interface().(sql.NullString); ok {
|
||||||
|
value = nil
|
||||||
|
if ns.Valid {
|
||||||
|
value = ns.String
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
value = field.String()
|
||||||
|
}
|
||||||
|
case TypeFloatField, TypeDecimalField:
|
||||||
|
if nf, ok := field.Interface().(sql.NullFloat64); ok {
|
||||||
|
value = nil
|
||||||
|
if nf.Valid {
|
||||||
|
value = nf.Float64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vu := field.Interface()
|
||||||
|
if _, ok := vu.(float32); ok {
|
||||||
|
value, _ = StrTo(ToStr(vu)).Float64()
|
||||||
|
} else {
|
||||||
|
value = field.Float()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case TypeDateField, TypeDateTimeField:
|
case TypeDateField, TypeDateTimeField:
|
||||||
value = field.Interface()
|
value = field.Interface()
|
||||||
@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
case fi.fieldType&IsPostiveIntegerField > 0:
|
case fi.fieldType&IsPostiveIntegerField > 0:
|
||||||
value = field.Uint()
|
value = field.Uint()
|
||||||
case fi.fieldType&IsIntegerField > 0:
|
case fi.fieldType&IsIntegerField > 0:
|
||||||
value = field.Int()
|
if ni, ok := field.Interface().(sql.NullInt64); ok {
|
||||||
|
value = nil
|
||||||
|
if ni.Valid {
|
||||||
|
value = ni.Int64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
value = field.Int()
|
||||||
|
}
|
||||||
case fi.fieldType&IsRelField > 0:
|
case fi.fieldType&IsRelField > 0:
|
||||||
if field.IsNil() {
|
if field.IsNil() {
|
||||||
value = nil
|
value = nil
|
||||||
@ -144,6 +172,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
switch fi.fieldType {
|
switch fi.fieldType {
|
||||||
case TypeDateField, TypeDateTimeField:
|
case TypeDateField, TypeDateTimeField:
|
||||||
if fi.auto_now || fi.auto_now_add && insert {
|
if fi.auto_now || fi.auto_now_add && insert {
|
||||||
|
if insert {
|
||||||
|
if t, ok := value.(time.Time); ok && !t.IsZero() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
tnow := time.Now()
|
tnow := time.Now()
|
||||||
d.ins.TimeToDB(&tnow, tz)
|
d.ins.TimeToDB(&tnow, tz)
|
||||||
value = tnow
|
value = tnow
|
||||||
@ -883,13 +916,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
|
|||||||
}
|
}
|
||||||
arg := params[0]
|
arg := params[0]
|
||||||
|
|
||||||
if operator == "in" {
|
switch operator {
|
||||||
|
case "in":
|
||||||
marks := make([]string, len(params))
|
marks := make([]string, len(params))
|
||||||
for i, _ := range marks {
|
for i, _ := range marks {
|
||||||
marks[i] = "?"
|
marks[i] = "?"
|
||||||
}
|
}
|
||||||
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
|
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
|
||||||
} else {
|
case "between":
|
||||||
|
if len(params) != 2 {
|
||||||
|
panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params)))
|
||||||
|
}
|
||||||
|
sql = "BETWEEN ? AND ?"
|
||||||
|
default:
|
||||||
if len(params) > 1 {
|
if len(params) > 1 {
|
||||||
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
|
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
|
||||||
}
|
}
|
||||||
@ -1117,17 +1156,37 @@ setValue:
|
|||||||
switch {
|
switch {
|
||||||
case fieldType == TypeBooleanField:
|
case fieldType == TypeBooleanField:
|
||||||
if isNative {
|
if isNative {
|
||||||
if value == nil {
|
if nb, ok := field.Interface().(sql.NullBool); ok {
|
||||||
value = false
|
if value == nil {
|
||||||
|
nb.Valid = false
|
||||||
|
} else {
|
||||||
|
nb.Bool = value.(bool)
|
||||||
|
nb.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(nb))
|
||||||
|
} else {
|
||||||
|
if value == nil {
|
||||||
|
value = false
|
||||||
|
}
|
||||||
|
field.SetBool(value.(bool))
|
||||||
}
|
}
|
||||||
field.SetBool(value.(bool))
|
|
||||||
}
|
}
|
||||||
case fieldType == TypeCharField || fieldType == TypeTextField:
|
case fieldType == TypeCharField || fieldType == TypeTextField:
|
||||||
if isNative {
|
if isNative {
|
||||||
if value == nil {
|
if ns, ok := field.Interface().(sql.NullString); ok {
|
||||||
value = ""
|
if value == nil {
|
||||||
|
ns.Valid = false
|
||||||
|
} else {
|
||||||
|
ns.String = value.(string)
|
||||||
|
ns.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(ns))
|
||||||
|
} else {
|
||||||
|
if value == nil {
|
||||||
|
value = ""
|
||||||
|
}
|
||||||
|
field.SetString(value.(string))
|
||||||
}
|
}
|
||||||
field.SetString(value.(string))
|
|
||||||
}
|
}
|
||||||
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
|
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
|
||||||
if isNative {
|
if isNative {
|
||||||
@ -1146,18 +1205,39 @@ setValue:
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if isNative {
|
if isNative {
|
||||||
if value == nil {
|
if ni, ok := field.Interface().(sql.NullInt64); ok {
|
||||||
value = int64(0)
|
if value == nil {
|
||||||
|
ni.Valid = false
|
||||||
|
} else {
|
||||||
|
ni.Int64 = value.(int64)
|
||||||
|
ni.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(ni))
|
||||||
|
} else {
|
||||||
|
if value == nil {
|
||||||
|
value = int64(0)
|
||||||
|
}
|
||||||
|
field.SetInt(value.(int64))
|
||||||
}
|
}
|
||||||
field.SetInt(value.(int64))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
|
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
|
||||||
if isNative {
|
if isNative {
|
||||||
if value == nil {
|
if nf, ok := field.Interface().(sql.NullFloat64); ok {
|
||||||
value = float64(0)
|
if value == nil {
|
||||||
|
nf.Valid = false
|
||||||
|
} else {
|
||||||
|
nf.Float64 = value.(float64)
|
||||||
|
nf.Valid = true
|
||||||
|
}
|
||||||
|
field.Set(reflect.ValueOf(nf))
|
||||||
|
} else {
|
||||||
|
|
||||||
|
if value == nil {
|
||||||
|
value = float64(0)
|
||||||
|
}
|
||||||
|
field.SetFloat(value.(float64))
|
||||||
}
|
}
|
||||||
field.SetFloat(value.(float64))
|
|
||||||
}
|
}
|
||||||
case fieldType&IsRelField > 0:
|
case fieldType&IsRelField > 0:
|
||||||
if value != nil {
|
if value != nil {
|
||||||
|
@ -168,7 +168,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if dataBaseCache.add(aliasName, al) == false {
|
if dataBaseCache.add(aliasName, al) == false {
|
||||||
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return al, nil
|
return al, nil
|
||||||
@ -239,7 +239,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
|||||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||||
al.TZ = tz
|
al.TZ = tz
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
|
return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -260,3 +260,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
|
|||||||
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
|
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get *sql.DB from registered database by db alias name.
|
||||||
|
// Use "default" as alias name if you not set.
|
||||||
|
func GetDB(aliasNames ...string) (*sql.DB, error) {
|
||||||
|
var name string
|
||||||
|
if len(aliasNames) > 0 {
|
||||||
|
name = aliasNames[0]
|
||||||
|
} else {
|
||||||
|
name = "default"
|
||||||
|
}
|
||||||
|
if al, ok := dataBaseCache.get(name); ok {
|
||||||
|
return al.DB, nil
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -98,3 +98,9 @@ func (mc *_modelCache) clean() {
|
|||||||
mc.cacheByFN = make(map[string]*modelInfo)
|
mc.cacheByFN = make(map[string]*modelInfo)
|
||||||
mc.done = false
|
mc.done = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean model cache. Then you can re-RegisterModel.
|
||||||
|
// Common use this api for test case.
|
||||||
|
func ResetModelCache() {
|
||||||
|
modelCache.clean()
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@ -116,27 +117,31 @@ type Data struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DataNull struct {
|
type DataNull struct {
|
||||||
Id int
|
Id int
|
||||||
Boolean bool `orm:"null"`
|
Boolean bool `orm:"null"`
|
||||||
Char string `orm:"null;size(50)"`
|
Char string `orm:"null;size(50)"`
|
||||||
Text string `orm:"null;type(text)"`
|
Text string `orm:"null;type(text)"`
|
||||||
Date time.Time `orm:"null;type(date)"`
|
Date time.Time `orm:"null;type(date)"`
|
||||||
DateTime time.Time `orm:"null;column(datetime)""`
|
DateTime time.Time `orm:"null;column(datetime)""`
|
||||||
Byte byte `orm:"null"`
|
Byte byte `orm:"null"`
|
||||||
Rune rune `orm:"null"`
|
Rune rune `orm:"null"`
|
||||||
Int int `orm:"null"`
|
Int int `orm:"null"`
|
||||||
Int8 int8 `orm:"null"`
|
Int8 int8 `orm:"null"`
|
||||||
Int16 int16 `orm:"null"`
|
Int16 int16 `orm:"null"`
|
||||||
Int32 int32 `orm:"null"`
|
Int32 int32 `orm:"null"`
|
||||||
Int64 int64 `orm:"null"`
|
Int64 int64 `orm:"null"`
|
||||||
Uint uint `orm:"null"`
|
Uint uint `orm:"null"`
|
||||||
Uint8 uint8 `orm:"null"`
|
Uint8 uint8 `orm:"null"`
|
||||||
Uint16 uint16 `orm:"null"`
|
Uint16 uint16 `orm:"null"`
|
||||||
Uint32 uint32 `orm:"null"`
|
Uint32 uint32 `orm:"null"`
|
||||||
Uint64 uint64 `orm:"null"`
|
Uint64 uint64 `orm:"null"`
|
||||||
Float32 float32 `orm:"null"`
|
Float32 float32 `orm:"null"`
|
||||||
Float64 float64 `orm:"null"`
|
Float64 float64 `orm:"null"`
|
||||||
Decimal float64 `orm:"digits(8);decimals(4);null"`
|
Decimal float64 `orm:"digits(8);decimals(4);null"`
|
||||||
|
NullString sql.NullString `orm:"null"`
|
||||||
|
NullBool sql.NullBool `orm:"null"`
|
||||||
|
NullFloat64 sql.NullFloat64 `orm:"null"`
|
||||||
|
NullInt64 sql.NullInt64 `orm:"null"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// only for mysql
|
// only for mysql
|
||||||
@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm
|
|||||||
|
|
||||||
|
|
||||||
#### Sqlite3
|
#### Sqlite3
|
||||||
touch /path/to/orm_test.db
|
|
||||||
export ORM_DRIVER=sqlite3
|
export ORM_DRIVER=sqlite3
|
||||||
export ORM_SOURCE=/path/to/orm_test.db
|
export ORM_SOURCE='file:memory_test?mode=memory'
|
||||||
go test -v github.com/astaxie/beego/orm
|
go test -v github.com/astaxie/beego/orm
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
|
|||||||
// return field type as type constant from reflect.Value
|
// return field type as type constant from reflect.Value
|
||||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||||
elm := reflect.Indirect(val)
|
elm := reflect.Indirect(val)
|
||||||
switch elm.Kind() {
|
switch elm.Interface().(type) {
|
||||||
case reflect.Int8:
|
case int8:
|
||||||
ft = TypeBitField
|
ft = TypeBitField
|
||||||
case reflect.Int16:
|
case int16:
|
||||||
ft = TypeSmallIntegerField
|
ft = TypeSmallIntegerField
|
||||||
case reflect.Int32, reflect.Int:
|
case int32, int:
|
||||||
ft = TypeIntegerField
|
ft = TypeIntegerField
|
||||||
case reflect.Int64:
|
case int64, sql.NullInt64:
|
||||||
ft = TypeBigIntegerField
|
ft = TypeBigIntegerField
|
||||||
case reflect.Uint8:
|
case uint8:
|
||||||
ft = TypePositiveBitField
|
ft = TypePositiveBitField
|
||||||
case reflect.Uint16:
|
case uint16:
|
||||||
ft = TypePositiveSmallIntegerField
|
ft = TypePositiveSmallIntegerField
|
||||||
case reflect.Uint32, reflect.Uint:
|
case uint32, uint:
|
||||||
ft = TypePositiveIntegerField
|
ft = TypePositiveIntegerField
|
||||||
case reflect.Uint64:
|
case uint64:
|
||||||
ft = TypePositiveBigIntegerField
|
ft = TypePositiveBigIntegerField
|
||||||
case reflect.Float32, reflect.Float64:
|
case float32, float64, sql.NullFloat64:
|
||||||
ft = TypeFloatField
|
ft = TypeFloatField
|
||||||
case reflect.Bool:
|
case bool, sql.NullBool:
|
||||||
ft = TypeBooleanField
|
ft = TypeBooleanField
|
||||||
case reflect.String:
|
case string, sql.NullString:
|
||||||
ft = TypeCharField
|
ft = TypeCharField
|
||||||
case reflect.Invalid:
|
|
||||||
default:
|
default:
|
||||||
if elm.CanInterface() {
|
if elm.CanInterface() {
|
||||||
if _, ok := elm.Interface().(time.Time); ok {
|
if _, ok := elm.Interface().(time.Time); ok {
|
||||||
|
@ -2,6 +2,7 @@ package orm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
@ -138,6 +139,15 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetDB(t *testing.T) {
|
||||||
|
if db, err := GetDB(); err != nil {
|
||||||
|
throwFailNow(t, err)
|
||||||
|
} else {
|
||||||
|
err = db.Ping()
|
||||||
|
throwFailNow(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSyncDb(t *testing.T) {
|
func TestSyncDb(t *testing.T) {
|
||||||
RegisterModel(new(Data), new(DataNull))
|
RegisterModel(new(Data), new(DataNull))
|
||||||
RegisterModel(new(User))
|
RegisterModel(new(User))
|
||||||
@ -258,12 +268,45 @@ func TestNullDataTypes(t *testing.T) {
|
|||||||
err = dORM.Read(&d)
|
err = dORM.Read(&d)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
|
throwFail(t, AssertIs(d.NullBool.Valid, false))
|
||||||
|
throwFail(t, AssertIs(d.NullString.Valid, false))
|
||||||
|
throwFail(t, AssertIs(d.NullInt64.Valid, false))
|
||||||
|
throwFail(t, AssertIs(d.NullFloat64.Valid, false))
|
||||||
|
|
||||||
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
|
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
d = DataNull{Id: 2}
|
d = DataNull{Id: 2}
|
||||||
err = dORM.Read(&d)
|
err = dORM.Read(&d)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
|
|
||||||
|
d = DataNull{
|
||||||
|
DateTime: time.Now(),
|
||||||
|
NullString: sql.NullString{"test", true},
|
||||||
|
NullBool: sql.NullBool{true, true},
|
||||||
|
NullInt64: sql.NullInt64{42, true},
|
||||||
|
NullFloat64: sql.NullFloat64{42.42, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err = dORM.Insert(&d)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(id, 3))
|
||||||
|
|
||||||
|
d = DataNull{Id: 3}
|
||||||
|
err = dORM.Read(&d)
|
||||||
|
throwFail(t, err)
|
||||||
|
|
||||||
|
throwFail(t, AssertIs(d.NullBool.Valid, true))
|
||||||
|
throwFail(t, AssertIs(d.NullBool.Bool, true))
|
||||||
|
|
||||||
|
throwFail(t, AssertIs(d.NullString.Valid, true))
|
||||||
|
throwFail(t, AssertIs(d.NullString.String, "test"))
|
||||||
|
|
||||||
|
throwFail(t, AssertIs(d.NullInt64.Valid, true))
|
||||||
|
throwFail(t, AssertIs(d.NullInt64.Int64, 42))
|
||||||
|
|
||||||
|
throwFail(t, AssertIs(d.NullFloat64.Valid, true))
|
||||||
|
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCRUD(t *testing.T) {
|
func TestCRUD(t *testing.T) {
|
||||||
@ -619,6 +662,14 @@ func TestOperators(t *testing.T) {
|
|||||||
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
|
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 2))
|
throwFail(t, AssertIs(num, 2))
|
||||||
|
|
||||||
|
num, err = qs.Filter("id__between", 2, 3).Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 2))
|
||||||
|
|
||||||
|
num, err = qs.Filter("id__between", []int{2, 3}).Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetCond(t *testing.T) {
|
func TestSetCond(t *testing.T) {
|
||||||
@ -1577,7 +1628,6 @@ func TestDelete(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 4))
|
throwFail(t, AssertIs(num, 4))
|
||||||
|
|
||||||
fmt.Println("...")
|
|
||||||
qs = dORM.QueryTable("comment")
|
qs = dORM.QueryTable("comment")
|
||||||
num, err = qs.Filter("Post__User", 3).Delete()
|
num, err = qs.Filter("Post__User", 3).Delete()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -1646,10 +1696,10 @@ func TestTransaction(t *testing.T) {
|
|||||||
func TestReadOrCreate(t *testing.T) {
|
func TestReadOrCreate(t *testing.T) {
|
||||||
u := &User{
|
u := &User{
|
||||||
UserName: "Kyle",
|
UserName: "Kyle",
|
||||||
Email: "kylemcc@gmail.com",
|
Email: "kylemcc@gmail.com",
|
||||||
Password: "other_pass",
|
Password: "other_pass",
|
||||||
Status: 7,
|
Status: 7,
|
||||||
IsStaff: false,
|
IsStaff: false,
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ package auth
|
|||||||
// }
|
// }
|
||||||
// return false
|
// return false
|
||||||
// }
|
// }
|
||||||
// authPlugin := auth.NewBasicAuthenticator(SecretAuth)
|
// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm")
|
||||||
// beego.AddFilter("*","AfterStatic",authPlugin)
|
// beego.AddFilter("*","AfterStatic",authPlugin)
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
23
router.go
23
router.go
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -545,14 +546,26 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
|
|
||||||
//static file server
|
//static file server
|
||||||
for prefix, staticDir := range StaticDir {
|
for prefix, staticDir := range StaticDir {
|
||||||
|
if len(prefix) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if r.URL.Path == "/favicon.ico" {
|
if r.URL.Path == "/favicon.ico" {
|
||||||
file := staticDir + r.URL.Path
|
file := path.Join(staticDir, r.URL.Path)
|
||||||
http.ServeFile(w, r, file)
|
if utils.FileExists(file) {
|
||||||
w.started = true
|
http.ServeFile(w, r, file)
|
||||||
goto Admin
|
w.started = true
|
||||||
|
goto Admin
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||||
file := staticDir + r.URL.Path[len(prefix):]
|
if len(r.URL.Path) > len(prefix) && r.URL.Path[len(prefix)] != '/' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if r.URL.Path == prefix && prefix[len(prefix)-1] != '/' {
|
||||||
|
http.Redirect(rw, r, r.URL.Path+"/", 302)
|
||||||
|
goto Admin
|
||||||
|
}
|
||||||
|
file := path.Join(staticDir, r.URL.Path[len(prefix):])
|
||||||
finfo, err := os.Stat(file)
|
finfo, err := os.Stat(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if RunMode == "dev" {
|
if RunMode == "dev" {
|
||||||
|
203
session/sess_couchbase.go
Normal file
203
session/sess_couchbase.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/couchbaselabs/go-couchbase"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var couchbpder = &CouchbaseProvider{}
|
||||||
|
|
||||||
|
type CouchbaseSessionStore struct {
|
||||||
|
b *couchbase.Bucket
|
||||||
|
sid string
|
||||||
|
lock sync.RWMutex
|
||||||
|
values map[interface{}]interface{}
|
||||||
|
maxlifetime int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type CouchbaseProvider struct {
|
||||||
|
maxlifetime int64
|
||||||
|
savePath string
|
||||||
|
pool string
|
||||||
|
bucket string
|
||||||
|
b *couchbase.Bucket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
|
||||||
|
cs.lock.Lock()
|
||||||
|
defer cs.lock.Unlock()
|
||||||
|
cs.values[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
|
||||||
|
cs.lock.RLock()
|
||||||
|
defer cs.lock.RUnlock()
|
||||||
|
if v, ok := cs.values[key]; ok {
|
||||||
|
return v
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
|
||||||
|
cs.lock.Lock()
|
||||||
|
defer cs.lock.Unlock()
|
||||||
|
delete(cs.values, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) Flush() error {
|
||||||
|
cs.lock.Lock()
|
||||||
|
defer cs.lock.Unlock()
|
||||||
|
cs.values = make(map[interface{}]interface{})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) SessionID() string {
|
||||||
|
return cs.sid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
|
defer cs.b.Close()
|
||||||
|
|
||||||
|
// if rs.values is empty, return directly
|
||||||
|
if len(cs.values) < 1 {
|
||||||
|
cs.b.Delete(cs.sid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bo, err := encodeGob(cs.values)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
|
||||||
|
c, err := couchbase.Connect(cp.savePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := c.GetPool(cp.pool)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket, err := pool.GetBucket(cp.bucket)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return bucket
|
||||||
|
}
|
||||||
|
|
||||||
|
// init couchbase session
|
||||||
|
// savepath like couchbase server REST/JSON URL
|
||||||
|
// e.g. http://host:port/, Pool, Bucket
|
||||||
|
func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
|
cp.maxlifetime = maxlifetime
|
||||||
|
configs := strings.Split(savePath, ",")
|
||||||
|
if len(configs) > 0 {
|
||||||
|
cp.savePath = configs[0]
|
||||||
|
}
|
||||||
|
if len(configs) > 1 {
|
||||||
|
cp.pool = configs[1]
|
||||||
|
}
|
||||||
|
if len(configs) > 2 {
|
||||||
|
cp.bucket = configs[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// read couchbase session by sid
|
||||||
|
func (cp *CouchbaseProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
|
cp.b = cp.getBucket()
|
||||||
|
|
||||||
|
var doc []byte
|
||||||
|
|
||||||
|
err := cp.b.Get(sid, &doc)
|
||||||
|
var kv map[interface{}]interface{}
|
||||||
|
if doc == nil {
|
||||||
|
kv = make(map[interface{}]interface{})
|
||||||
|
} else {
|
||||||
|
kv, err = decodeGob(doc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
|
||||||
|
return cs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) SessionExist(sid string) bool {
|
||||||
|
cp.b = cp.getBucket()
|
||||||
|
defer cp.b.Close()
|
||||||
|
|
||||||
|
var doc []byte
|
||||||
|
|
||||||
|
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
|
cp.b = cp.getBucket()
|
||||||
|
|
||||||
|
var doc []byte
|
||||||
|
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
|
||||||
|
cp.b.Set(sid, int(cp.maxlifetime), "")
|
||||||
|
} else {
|
||||||
|
err := cp.b.Delete(oldsid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cp.b.Get(sid, &doc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var kv map[interface{}]interface{}
|
||||||
|
if doc == nil {
|
||||||
|
kv = make(map[interface{}]interface{})
|
||||||
|
} else {
|
||||||
|
kv, err = decodeGob(doc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
|
||||||
|
return cs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
|
||||||
|
cp.b = cp.getBucket()
|
||||||
|
defer cp.b.Close()
|
||||||
|
|
||||||
|
cp.b.Delete(sid)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) SessionGC() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *CouchbaseProvider) SessionAll() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register("couchbase", couchbpder)
|
||||||
|
}
|
@ -152,8 +152,7 @@ func (fp *FileProvider) SessionExist(sid string) bool {
|
|||||||
func (fp *FileProvider) SessionDestroy(sid string) error {
|
func (fp *FileProvider) SessionDestroy(sid string) error {
|
||||||
filepder.lock.Lock()
|
filepder.lock.Lock()
|
||||||
defer filepder.lock.Unlock()
|
defer filepder.lock.Unlock()
|
||||||
|
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
|
||||||
os.Remove(path.Join(fp.savePath))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +129,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
|||||||
}
|
}
|
||||||
return c, err
|
return c, err
|
||||||
}, rp.poolsize)
|
}, rp.poolsize)
|
||||||
return nil
|
|
||||||
|
return rp.poollist.Get().Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// read redis session by sid
|
// read redis session by sid
|
||||||
|
@ -56,11 +56,10 @@ type managerConfig struct {
|
|||||||
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
||||||
Gclifetime int64 `json:"gclifetime"`
|
Gclifetime int64 `json:"gclifetime"`
|
||||||
Maxlifetime int64 `json:"maxLifetime"`
|
Maxlifetime int64 `json:"maxLifetime"`
|
||||||
Maxage int `json:"maxage"`
|
|
||||||
Secure bool `json:"secure"`
|
Secure bool `json:"secure"`
|
||||||
SessionIDHashFunc string `json:"sessionIDHashFunc"`
|
SessionIDHashFunc string `json:"sessionIDHashFunc"`
|
||||||
SessionIDHashKey string `json:"sessionIDHashKey"`
|
SessionIDHashKey string `json:"sessionIDHashKey"`
|
||||||
CookieLifeTime int64 `json:"cookieLifeTime"`
|
CookieLifeTime int `json:"cookieLifeTime"`
|
||||||
ProviderConfig string `json:"providerConfig"`
|
ProviderConfig string `json:"providerConfig"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,8 +124,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: manager.config.Secure}
|
Secure: manager.config.Secure}
|
||||||
if manager.config.Maxage >= 0 {
|
if manager.config.CookieLifeTime >= 0 {
|
||||||
cookie.MaxAge = manager.config.Maxage
|
cookie.MaxAge = manager.config.CookieLifeTime
|
||||||
}
|
}
|
||||||
if manager.config.EnableSetCookie {
|
if manager.config.EnableSetCookie {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
@ -144,8 +143,8 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: manager.config.Secure}
|
Secure: manager.config.Secure}
|
||||||
if manager.config.Maxage >= 0 {
|
if manager.config.CookieLifeTime >= 0 {
|
||||||
cookie.MaxAge = manager.config.Maxage
|
cookie.MaxAge = manager.config.CookieLifeTime
|
||||||
}
|
}
|
||||||
if manager.config.EnableSetCookie {
|
if manager.config.EnableSetCookie {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
@ -206,8 +205,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
|
|||||||
cookie.HttpOnly = true
|
cookie.HttpOnly = true
|
||||||
cookie.Path = "/"
|
cookie.Path = "/"
|
||||||
}
|
}
|
||||||
if manager.config.Maxage >= 0 {
|
if manager.config.CookieLifeTime >= 0 {
|
||||||
cookie.MaxAge = manager.config.Maxage
|
cookie.MaxAge = manager.config.CookieLifeTime
|
||||||
}
|
}
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
r.AddCookie(cookie)
|
r.AddCookie(cookie)
|
||||||
|
@ -67,7 +67,7 @@ const (
|
|||||||
fieldIdName = "captcha_id"
|
fieldIdName = "captcha_id"
|
||||||
fieldCaptchaName = "captcha"
|
fieldCaptchaName = "captcha"
|
||||||
cachePrefix = "captcha_"
|
cachePrefix = "captcha_"
|
||||||
urlPrefix = "/captcha/"
|
defaultURLPrefix = "/captcha/"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Captcha struct
|
// Captcha struct
|
||||||
@ -76,7 +76,7 @@ type Captcha struct {
|
|||||||
store cache.Cache
|
store cache.Cache
|
||||||
|
|
||||||
// url prefix for captcha image
|
// url prefix for captcha image
|
||||||
urlPrefix string
|
URLPrefix string
|
||||||
|
|
||||||
// specify captcha id input field name
|
// specify captcha id input field name
|
||||||
FieldIdName string
|
FieldIdName string
|
||||||
@ -155,7 +155,7 @@ func (c *Captcha) CreateCaptchaHtml() template.HTML {
|
|||||||
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
|
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
|
||||||
`<a class="captcha" href="javascript:">`+
|
`<a class="captcha" href="javascript:">`+
|
||||||
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
|
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
|
||||||
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
|
`</a>`, c.FieldIdName, value, c.URLPrefix, value, c.URLPrefix, value))
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new captcha id
|
// create a new captcha id
|
||||||
@ -224,14 +224,14 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
|
|||||||
cpt.StdHeight = stdHeight
|
cpt.StdHeight = stdHeight
|
||||||
|
|
||||||
if len(urlPrefix) == 0 {
|
if len(urlPrefix) == 0 {
|
||||||
urlPrefix = urlPrefix
|
urlPrefix = defaultURLPrefix
|
||||||
}
|
}
|
||||||
|
|
||||||
if urlPrefix[len(urlPrefix)-1] != '/' {
|
if urlPrefix[len(urlPrefix)-1] != '/' {
|
||||||
urlPrefix += "/"
|
urlPrefix += "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
cpt.urlPrefix = urlPrefix
|
cpt.URLPrefix = urlPrefix
|
||||||
|
|
||||||
return cpt
|
return cpt
|
||||||
}
|
}
|
||||||
@ -242,7 +242,7 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
|
|||||||
cpt := NewCaptcha(urlPrefix, store)
|
cpt := NewCaptcha(urlPrefix, store)
|
||||||
|
|
||||||
// create filter for serve captcha image
|
// create filter for serve captcha image
|
||||||
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
|
beego.AddFilter(cpt.URLPrefix+":", "BeforeRouter", cpt.Handler)
|
||||||
|
|
||||||
// add to template func map
|
// add to template func map
|
||||||
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
|
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
|
||||||
|
@ -443,7 +443,7 @@ func (b Base64) GetLimitValue() interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just for chinese mobile phone number
|
// just for chinese mobile phone number
|
||||||
var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][01236789]))\\d{8}$")
|
var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][012356789]))\\d{8}$")
|
||||||
|
|
||||||
type Mobile struct {
|
type Mobile struct {
|
||||||
Match
|
Match
|
||||||
|
Loading…
Reference in New Issue
Block a user