1
0
mirror of https://github.com/astaxie/beego.git synced 2024-06-30 15:24:13 +00:00

Merge branch 'release/release1.1.0'

This commit is contained in:
asta.xie 2014-02-10 11:39:03 +08:00
commit 92196c602b
90 changed files with 4274 additions and 831 deletions

View File

@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
beego is licensed under the Apache Licence, Version 2.0 beego is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html). (http://www.apache.org/licenses/LICENSE-2.0.html).
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
## Use case [koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)

8
app.go
View File

@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
return app return app
} }
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
app.Handlers.AddAutoPrefix(prefix, c)
return app
}
// UrlFor creates a url with another registered controller handler with params. // UrlFor creates a url with another registered controller handler with params.
// The endpoint is formed as path.controller.name to defined the controller method which will run. // The endpoint is formed as path.controller.name to defined the controller method which will run.
// The values need key-pair data to assign into controller method. // The values need key-pair data to assign into controller method.

128
beego.go
View File

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/astaxie/beego/middleware" "github.com/astaxie/beego/middleware"
@ -11,7 +12,77 @@ import (
) )
// beego web framework version. // beego web framework version.
const VERSION = "1.0.1" const VERSION = "1.1.0"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
type groupRouter struct {
pattern string
controller ControllerInterface
mappingMethods string
}
// RouterGroups which will store routers
type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
return make([]groupRouter, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
pattern,
c,
mappingMethod[0],
}
} else {
newRG = groupRouter{
pattern,
c,
"",
}
}
gr = append(gr, newRG)
}
func (gr GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
gr = append(gr, newRG)
}
// AddGroupRouter with the prefix
// it will register the router in BeeApp
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
func AddGroupRouter(prefix string, groups GroupRouters) *App {
for _, v := range groups {
if v.pattern == "" {
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
} else if v.mappingMethods != "" {
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
} else {
BeeApp.Router(prefix+v.pattern, v.controller)
}
}
return BeeApp
}
// Router adds a patterned controller handler to BeeApp. // Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router. // it's an alias method of App.Router.
@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
return BeeApp return BeeApp
} }
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.AutoRouterWithPrefix(prefix, c)
return BeeApp
}
// ErrorHandler registers http.HandlerFunc to each http err code string. // ErrorHandler registers http.HandlerFunc to each http err code string.
// usage: // usage:
// beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("404",NotFound)
@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
return BeeApp return BeeApp
} }
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
}
// Run beego application. // Run beego application.
// it's alias of App.Run. // it's alias of App.Run.
func Run() { func Run() {
@ -99,18 +183,32 @@ func Run() {
} }
} }
//init mime // do hooks function
initMime() for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn { if SessionOn {
GlobalSessions, _ = session.NewManager(SessionProvider, var err error
SessionName, sessionConfig := AppConfig.String("sessionConfig")
SessionGCMaxLifetime, if sessionConfig == "" {
SessionSavePath, sessionConfig = `{"cookieName":"` + SessionName + `",` +
HttpTLS, `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
SessionHashFunc, `"providerConfig":"` + SessionSavePath + `",` +
SessionHashKey, `"secure":` + strconv.FormatBool(HttpTLS) + `,` +
SessionCookieLifeTime) `"sessionIDHashFunc":"` + SessionHashFunc + `",` +
`"sessionIDHashKey":"` + SessionHashKey + `",` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC() go GlobalSessions.GC()
} }
@ -123,7 +221,7 @@ func Run() {
middleware.VERSION = VERSION middleware.VERSION = VERSION
middleware.AppName = AppName middleware.AppName = AppName
middleware.RegisterErrorHander() middleware.RegisterErrorHandler()
if EnableAdmin { if EnableAdmin {
go BeeAdminApp.Run() go BeeAdminApp.Run()
@ -131,3 +229,9 @@ func Run() {
BeeApp.Run() BeeApp.Run()
} }
func init() {
hooks = make([]hookfunc, 0)
//init mime
AddAPPStartHook(initMime)
}

2
cache/README.md vendored
View File

@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter ## Memcache adapter
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Configure like this: Configure like this:

50
cache/cache_test.go vendored
View File

@ -5,7 +5,7 @@ import (
"time" "time"
) )
func Test_cache(t *testing.T) { func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`) bm, err := NewCache("memory", `{"interval":20}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
t.Error("delete err") t.Error("delete err")
} }
} }
func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
}

10
cache/file.go vendored
View File

@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
var cfg map[string]string var cfg map[string]string
json.Unmarshal([]byte(config), &cfg) json.Unmarshal([]byte(config), &cfg)
//fmt.Println(cfg) //fmt.Println(cfg)
//fmt.Println(config)
if _, ok := cfg["CachePath"]; !ok { if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath cfg["CachePath"] = FileCachePath
} }
@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
return "" return ""
} }
var to FileCacheItem var to FileCacheItem
Gob_decode([]byte(filedata), &to) Gob_decode(filedata, &to)
if to.Expired < time.Now().Unix() { if to.Expired < time.Now().Unix() {
return "" return ""
} }
@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
func (this *FileCache) Incr(key string) error { func (this *FileCache) Incr(key string) error {
data := this.Get(key) data := this.Get(key)
var incr int var incr int
fmt.Println(reflect.TypeOf(data).Name()) //fmt.Println(reflect.TypeOf(data).Name())
if reflect.TypeOf(data).Name() != "int" { if reflect.TypeOf(data).Name() != "int" {
incr = 0 incr = 0
} else { } else {
@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
// Clean cached files. // Clean cached files.
// not implemented. // not implemented.
func (this *FileCache) ClearAll() error { func (this *FileCache) ClearAll() error {
//this.CachePath .递归删除 //this.CachePath
return nil return nil
} }
@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
} }
// Gob decodes file cache item. // Gob decodes file cache item.
func Gob_decode(data []byte, to interface{}) error { func Gob_decode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf) dec := gob.NewDecoder(buf)
return dec.Decode(&to) return dec.Decode(&to)

41
cache/memcache.go vendored
View File

@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
// get value from memcache. // get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} { func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
// put value to memcache. only support string. // put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, ok := val.(string) v, ok := val.(string)
if !ok { if !ok {
@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// delete value in memcache. // delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error { func (rc *MemcacheCache) Delete(key string) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
_, err := rc.c.Delete(key) _, err := rc.c.Delete(key)
return err return err
@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
// check value exists in memcache. // check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool { func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
// clear all cached in memcache. // clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error { func (rc *MemcacheCache) ClearAll() error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
err := rc.c.FlushAll() err := rc.c.FlushAll()
return err return err
@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
rc.c = rc.connectInit() var err error
if rc.c == nil { rc.c, err = rc.connectInit()
if err != nil {
return errors.New("dial tcp conn error") return errors.New("dial tcp conn error")
} }
return nil return nil
} }
// connect to memcache and keep the connection. // connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() *memcache.Connection { func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rc.conninfo) c, err := memcache.Connect(rc.conninfo)
if err != nil { if err != nil {
return nil return nil, err
} }
return c return c, nil
} }
func init() { func init() {

118
cache/redis.go vendored
View File

@ -3,6 +3,7 @@ package cache
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"time"
"github.com/beego/redigo/redis" "github.com/beego/redigo/redis"
) )
@ -14,7 +15,7 @@ var (
// Redis cache adapter. // Redis cache adapter.
type RedisCache struct { type RedisCache struct {
c redis.Conn p *redis.Pool // redis connection pool
conninfo string conninfo string
key string key string
} }
@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey} return &RedisCache{key: DefaultKey}
} }
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// Get cache from redis. // Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} { func (rc *RedisCache) Get(key string) interface{} {
if rc.c == nil { v, err := rc.do("HGET", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return nil
}
}
v, err := rc.c.Do("HGET", rc.key, key)
if err != nil { if err != nil {
return nil return nil
} }
return v return v
} }
// put cache to redis. // put cache to redis.
// timeout is ignored. // timeout is ignored.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { _, err := rc.do("HSET", rc.key, key, val)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HSET", rc.key, key, val)
return err return err
} }
// delete cache in redis. // delete cache in redis.
func (rc *RedisCache) Delete(key string) error { func (rc *RedisCache) Delete(key string) error {
if rc.c == nil { _, err := rc.do("HDEL", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HDEL", rc.key, key)
return err return err
} }
// check cache exist in redis. // check cache exist in redis.
func (rc *RedisCache) IsExist(key string) bool { func (rc *RedisCache) IsExist(key string) bool {
if rc.c == nil { v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
if err != nil { if err != nil {
return false return false
} }
return v return v
} }
// increase counter in redis. // increase counter in redis.
func (rc *RedisCache) Incr(key string) error { func (rc *RedisCache) Incr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
var err error return err
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
} }
// decrease counter in redis. // decrease counter in redis.
func (rc *RedisCache) Decr(key string) error { func (rc *RedisCache) Decr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
var err error return err
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
} }
// clean all cache in redis. delete this redis collection. // clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error { func (rc *RedisCache) ClearAll() error {
if rc.c == nil { _, err := rc.do("DEL", rc.key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("DEL", rc.key)
return err return err
} }
@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
func (rc *RedisCache) StartAndGC(config string) error { func (rc *RedisCache) StartAndGC(config string) error {
var cf map[string]string var cf map[string]string
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok { if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey cf["key"] = DefaultKey
} }
if _, ok := cf["conn"]; !ok { if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.key = cf["key"] rc.key = cf["key"]
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
var err error rc.connectInit()
rc.c, err = rc.connectInit()
if err != nil { c := rc.p.Get()
defer c.Close()
if err := c.Err(); err != nil {
return err return err
} }
if rc.c == nil {
return errors.New("dial tcp conn error")
}
return nil return nil
} }
// connect to redis. // connect to redis.
func (rc *RedisCache) connectInit() (redis.Conn, error) { func (rc *RedisCache) connectInit() {
c, err := redis.Dial("tcp", rc.conninfo) // initialize a new pool
if err != nil { rc.p = &redis.Pool{
return nil, err MaxIdle: 3,
IdleTimeout: 180 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
}
return c, nil
},
} }
return c, nil
} }
func init() { func init() {

View File

@ -40,6 +40,7 @@ var (
SessionHashFunc string // session hash generation func. SessionHashFunc string // session hash generation func.
SessionHashKey string // session hash salt string. SessionHashKey string // session hash salt string.
SessionCookieLifeTime int // the life time of session id in cookie. SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
UseFcgi bool UseFcgi bool
MaxMemory int64 MaxMemory int64
EnableGzip bool // flag of enable gzip EnableGzip bool // flag of enable gzip
@ -96,6 +97,7 @@ func init() {
SessionHashFunc = "sha1" SessionHashFunc = "sha1"
SessionHashKey = "beegoserversessionkey" SessionHashKey = "beegoserversessionkey"
SessionCookieLifeTime = 0 //set cookie default is the brower life SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false UseFcgi = false
@ -139,6 +141,7 @@ func init() {
func ParseConfig() (err error) { func ParseConfig() (err error) {
AppConfig, err = config.NewConfig("ini", AppConfigPath) AppConfig, err = config.NewConfig("ini", AppConfigPath)
if err != nil { if err != nil {
AppConfig = config.NewFakeConfig()
return err return err
} else { } else {
HttpAddr = AppConfig.String("HttpAddr") HttpAddr = AppConfig.String("HttpAddr")

View File

@ -6,8 +6,9 @@ import (
// ConfigContainer defines how to get and set value from configuration raw data. // ConfigContainer defines how to get and set value from configuration raw data.
type ConfigContainer interface { type ConfigContainer interface {
Set(key, val string) error // support section::key type in given key when using ini type. Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error) Int(key string) (int, error)
Int64(key string) (int64, error) Int64(key string) (int64, error)
Bool(key string) (bool, error) Bool(key string) (bool, error)

62
config/fake.go Normal file
View File

@ -0,0 +1,62 @@
package config
import (
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
key = strings.ToLower(key)
return c.data[key]
}
func (c *fakeConfigContainer) Set(key, val string) error {
key = strings.ToLower(key)
c.data[key] = val
return nil
}
func (c *fakeConfigContainer) String(key string) string {
return c.getData(key)
}
func (c *fakeConfigContainer) Strings(key string) []string {
return strings.Split(c.getData(key), ";")
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
key = strings.ToLower(key)
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
var _ ConfigContainer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
return &fakeConfigContainer{
data: make(map[string]string),
}
}

View File

@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
return c.getdata(key) return c.getdata(key)
} }
// Strings returns the []string value for a given key.
func (c *IniConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
// if write to one section, the key need be "section::key". // if write to one section, the key need be "section::key".
// if the section is not existed, it panics. // if the section is not existed, it panics.

View File

@ -19,6 +19,7 @@ copyrequestbody = true
key1="asta" key1="asta"
key2 = "xie" key2 = "xie"
CaseInsensitive = true CaseInsensitive = true
peers = one;two;three
` `
func TestIni(t *testing.T) { func TestIni(t *testing.T) {
@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true { if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
t.Fatal("get demo.caseinsensitive error") t.Fatal("get demo.caseinsensitive error")
} }
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
t.Fatal("get strings error", data)
} else if data[0] != "one" {
t.Fatal("get first params error not equat to one")
}
} }

View File

@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error { func (c *JsonConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/beego/x2j" "github.com/beego/x2j"
@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error { func (c *XMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"strings"
"sync" "sync"
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error { func (c *YAMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -3,7 +3,6 @@ package beego
import ( import (
"bytes" "bytes"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
@ -22,6 +21,7 @@ import (
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
) )
var ( var (
@ -45,6 +45,7 @@ type Controller struct {
CruSession session.SessionStore CruSession session.SessionStore
XSRFExpire int XSRFExpire int
AppController interface{} AppController interface{}
EnableReander bool
} }
// ControllerInterface is an interface to uniform all controller handler. // ControllerInterface is an interface to uniform all controller handler.
@ -74,6 +75,8 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.Ctx = ctx c.Ctx = ctx
c.TplExt = "tpl" c.TplExt = "tpl"
c.AppController = app c.AppController = app
c.EnableReander = true
c.Data = ctx.Input.Data
} }
// Prepare runs after Init before request function execution. // Prepare runs after Init before request function execution.
@ -123,6 +126,9 @@ func (c *Controller) Options() {
// Render sends the response with rendered template bytes as text/html type. // Render sends the response with rendered template bytes as text/html type.
func (c *Controller) Render() error { func (c *Controller) Render() error {
if !c.EnableReander {
return nil
}
rb, err := c.RenderBytes() rb, err := c.RenderBytes()
if err != nil { if err != nil {
@ -140,7 +146,7 @@ func (c *Controller) RenderString() (string, error) {
return string(b), e return string(b), e
} }
// RenderBytes returns the bytes of renderd tempate string. Do not send out response. // RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) { func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout //if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" { if c.Layout != "" {
@ -165,7 +171,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
if c.LayoutSections != nil { if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections { for sectionName, sectionTpl := range c.LayoutSections {
if (sectionTpl == "") { if sectionTpl == "" {
c.Data[sectionName] = "" c.Data[sectionName] = ""
continue continue
} }
@ -391,12 +397,14 @@ func (c *Controller) DelSession(name interface{}) {
// SessionRegenerateID regenerates session id for this session. // SessionRegenerateID regenerates session id for this session.
// the session data have no changes. // the session data have no changes.
func (c *Controller) SessionRegenerateID() { func (c *Controller) SessionRegenerateID() {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession c.Ctx.Input.CruSession = c.CruSession
} }
// DestroySession cleans session data and session cookie. // DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() { func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush()
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
} }
@ -454,7 +462,7 @@ func (c *Controller) XsrfToken() string {
} else { } else {
expire = int64(XSRFExpire) expire = int64(XSRFExpire)
} }
token = getRandomString(15) token = string(utils.RandomCreateBytes(15))
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire) c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
} }
c._xsrf_token = token c._xsrf_token = token
@ -491,14 +499,3 @@ func (c *Controller) XsrfFormHtml() string {
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
return c.controllerName, c.actionName return c.controllerName, c.actionName
} }
// getRandomString returns random string.
func getRandomString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
return string(bytes)
}

View File

@ -1,12 +1,13 @@
package controllers package controllers
import ( import (
"github.com/astaxie/beego"
"github.com/garyburd/go-websocket/websocket"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"time" "time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
) )
const ( const (

View File

@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if router == mr.pattern { if router == mr.pattern {
return true, nil return true, nil
} }
//pattern /admin router /admin/ match
//pattern /admin/ router /admin don't match, because url will 301 in router
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
return true, nil
}
if mr.hasregex { if mr.hasregex {
if !mr.regex.MatchString(router) { if !mr.regex.MatchString(router) {
return false, nil return false, nil
@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
return false, nil return false, nil
} }
func buildFilter(pattern string, filter FilterFunc) *FilterRouter { func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.params = make(map[int]string) mr.params = make(map[int]string)
mr.filterFunc = filter mr.filterFunc = filter
@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j := 0 j := 0
for i, part := range parts { for i, part := range parts {
if strings.HasPrefix(part, ":") { if strings.HasPrefix(part, ":") {
expr := "(.+)" expr := "(.*)"
//a user may choose to override the default expression //a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+) // similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 { if index := strings.Index(part, "("); index != -1 {
@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j++ j++
} }
if strings.HasPrefix(part, "*") { if strings.HasPrefix(part, "*") {
expr := "(.+)" expr := "(.*)"
if part == "*.*" { if part == "*.*" {
mr.params[j] = ":path" mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)" parts[i] = "([^.]+).([^.]+)"
@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
pattern = strings.Join(parts, "/") pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern) regex, regexErr := regexp.Compile(pattern)
if regexErr != nil { if regexErr != nil {
//TODO add error handling here to avoid panic return nil, regexErr
panic(regexErr)
} }
mr.regex = regex mr.regex = regex
mr.hasregex = true mr.hasregex = true
} }
mr.pattern = pattern mr.pattern = pattern
return mr return mr, nil
} }

View File

@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
t.Errorf("user define func can't run") t.Errorf("user define func can't run")
} }
} }
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}

View File

@ -7,6 +7,8 @@ import (
"net" "net"
) )
// ConnWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct { type ConnWriter struct {
lg *log.Logger lg *log.Logger
innerWriter io.WriteCloser innerWriter io.WriteCloser
@ -17,12 +19,15 @@ type ConnWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface { func NewConn() LoggerInterface {
conn := new(ConnWriter) conn := new(ConnWriter)
conn.Level = LevelTrace conn.Level = LevelTrace
return conn return conn
} }
// init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error { func (c *ConnWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error { func (c *ConnWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty.
func (c *ConnWriter) Flush() { func (c *ConnWriter) Flush() {
} }
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() { func (c *ConnWriter) Destroy() {
if c.innerWriter == nil { if c.innerWriter == nil {
return return

View File

@ -4,13 +4,35 @@ import (
"encoding/json" "encoding/json"
"log" "log"
"os" "os"
"runtime"
) )
type Brush func(string) string
func NewBrush(color string) Brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
return pre + color + "m" + text + reset
}
}
var colors = []Brush{
NewBrush("1;36"), // Trace cyan
NewBrush("1;34"), // Debug blue
NewBrush("1;32"), // Info green
NewBrush("1;33"), // Warn yellow
NewBrush("1;31"), // Error red
NewBrush("1;35"), // Critical purple
}
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct { type ConsoleWriter struct {
lg *log.Logger lg *log.Logger
Level int `json:"level"` Level int `json:"level"`
} }
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface { func NewConsole() LoggerInterface {
cw := new(ConsoleWriter) cw := new(ConsoleWriter)
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime) cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
@ -18,6 +40,8 @@ func NewConsole() LoggerInterface {
return cw return cw
} }
// init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error { func (c *ConsoleWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
@ -26,18 +50,25 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error { func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
} }
c.lg.Println(msg) if goos := runtime.GOOS; goos == "windows" {
c.lg.Println(msg)
} else {
c.lg.Println(colors[level](msg))
}
return nil return nil
} }
// implementing method. empty.
func (c *ConsoleWriter) Destroy() { func (c *ConsoleWriter) Destroy() {
} }
// implementing method. empty.
func (c *ConsoleWriter) Flush() { func (c *ConsoleWriter) Flush() {
} }

View File

@ -13,6 +13,8 @@ import (
"time" "time"
) )
// FileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct { type FileLogWriter struct {
*log.Logger *log.Logger
mw *MuxWriter mw *MuxWriter
@ -38,17 +40,20 @@ type FileLogWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// an *os.File writer with locker.
type MuxWriter struct { type MuxWriter struct {
sync.Mutex sync.Mutex
fd *os.File fd *os.File
} }
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) { func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
return l.fd.Write(b) return l.fd.Write(b)
} }
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) { func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil { if l.fd != nil {
l.fd.Close() l.fd.Close()
@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
l.fd = fd l.fd = fd
} }
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface { func NewFileWriter() LoggerInterface {
w := &FileLogWriter{ w := &FileLogWriter{
Filename: "", Filename: "",
@ -73,15 +79,16 @@ func NewFileWriter() LoggerInterface {
return w return w
} }
// jsonconfig like this // Init file logger with json config.
//{ // jsonconfig like:
// {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxlines":10000, // "maxlines":10000,
// "maxsize":1<<30, // "maxsize":1<<30,
// "daily":true, // "daily":true,
// "maxdays":15, // "maxdays":15,
// "rotate":true // "rotate":true
//} // }
func (w *FileLogWriter) Init(jsonconfig string) error { func (w *FileLogWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w) err := json.Unmarshal([]byte(jsonconfig), w)
if err != nil { if err != nil {
@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
return err return err
} }
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error { func (w *FileLogWriter) StartLogger() error {
fd, err := w.createLogFile() fd, err := w.createLogFile()
if err != nil { if err != nil {
@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
w.maxsize_cursize += size w.maxsize_cursize += size
} }
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error { func (w *FileLogWriter) WriteMsg(msg string, level int) error {
if level < w.Level { if level < w.Level {
return nil return nil
@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
return nil return nil
} }
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error { func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.Filename) _, err := os.Lstat(w.Filename)
if err == nil { // file exists if err == nil { // file exists
@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
}) })
} }
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() { func (w *FileLogWriter) Destroy() {
w.mw.fd.Close() w.mw.fd.Close()
} }
// flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() { func (w *FileLogWriter) Flush() {
w.mw.fd.Sync() w.mw.fd.Sync()
} }

View File

@ -6,6 +6,7 @@ import (
) )
const ( const (
// log message levels
LevelTrace = iota LevelTrace = iota
LevelDebug LevelDebug
LevelInfo LevelInfo
@ -16,6 +17,7 @@ const (
type loggerType func() LoggerInterface type loggerType func() LoggerInterface
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface { type LoggerInterface interface {
Init(config string) error Init(config string) error
WriteMsg(msg string, level int) error WriteMsg(msg string, level int) error
@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
adapters[name] = log adapters[name] = log
} }
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
@ -50,7 +54,9 @@ type logMsg struct {
msg string msg string
} }
// config need to be correct JSON as string: {"interval":360} // NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger { func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.msg = make(chan *logMsg, channellen) bl.msg = make(chan *logMsg, channellen)
@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
return bl return bl
} }
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error { func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
} }
} }
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error { func (bl *BeeLogger) DelLogger(adaptername string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
return nil return nil
} }
// set log message level.
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
func (bl *BeeLogger) SetLevel(l int) { func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) StartLogger() { func (bl *BeeLogger) StartLogger() {
for { for {
select { select {
@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
} }
} }
// log trace level message.
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
msg := fmt.Sprintf("[T] "+format, v...) msg := fmt.Sprintf("[T] "+format, v...)
bl.writerMsg(LevelTrace, msg) bl.writerMsg(LevelTrace, msg)
} }
// log debug level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) { func (bl *BeeLogger) Debug(format string, v ...interface{}) {
msg := fmt.Sprintf("[D] "+format, v...) msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg) bl.writerMsg(LevelDebug, msg)
} }
// log info level message.
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
msg := fmt.Sprintf("[I] "+format, v...) msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInfo, msg) bl.writerMsg(LevelInfo, msg)
} }
// log warn level message.
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
msg := fmt.Sprintf("[W] "+format, v...) msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarn, msg) bl.writerMsg(LevelWarn, msg)
} }
// log error level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) { func (bl *BeeLogger) Error(format string, v ...interface{}) {
msg := fmt.Sprintf("[E] "+format, v...) msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg) bl.writerMsg(LevelError, msg)
} }
// log critical level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) { func (bl *BeeLogger) Critical(format string, v ...interface{}) {
msg := fmt.Sprintf("[C] "+format, v...) msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg) bl.writerMsg(LevelCritical, msg)
} }
//flush all chan data // flush all chan data.
func (bl *BeeLogger) Flush() { func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() { func (bl *BeeLogger) Close() {
for { for {
if len(bl.msg) > 0 { if len(bl.msg) > 0 {

View File

@ -12,7 +12,7 @@ const (
subjectPhrase = "Diagnostic message from server" subjectPhrase = "Diagnostic message from server"
) )
// smtpWriter is used to send emails via given SMTP-server. // smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct { type SmtpWriter struct {
Username string `json:"Username"` Username string `json:"Username"`
Password string `json:"password"` Password string `json:"password"`
@ -22,10 +22,21 @@ type SmtpWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create smtp writer.
func NewSmtpWriter() LoggerInterface { func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace} return &SmtpWriter{Level: LevelTrace}
} }
// init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error { func (s *SmtpWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil { if err != nil {
@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error { func (s *SmtpWriter) WriteMsg(msg string, level int) error {
if level < s.Level { if level < s.Level {
return nil return nil
@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
return err return err
} }
// implementing method. empty.
func (s *SmtpWriter) Flush() { func (s *SmtpWriter) Flush() {
return return
} }
// implementing method. empty.
func (s *SmtpWriter) Destroy() { func (s *SmtpWriter) Destroy() {
return return
} }

View File

@ -5,16 +5,17 @@ import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"errors" "errors"
//"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
) )
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo) var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file. // OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable. // it's used for serve static file if gzip enable.
@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
modtime := osfileinfo.ModTime() modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size() fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path] cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize { if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
//fmt.Printf("read %s file %s from cache\n", zip, path)
} else { } else {
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
var content []byte var content []byte
if zip == "gzip" { if zip == "gzip" {
//将文件内容压缩到zipbuf中 //将文件内容压缩到zipbuf中
@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
} }
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize} cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi gmfim[zip+":"+path] = cfi
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
} }
return &MemFile{fi: cfi, offset: 0}, nil return &MemFile{fi: cfi, offset: 0}, nil
} }

View File

@ -61,6 +61,7 @@ var tpl = `
</html> </html>
` `
// render default application error page with error and stack string.
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) { func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl) t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string) data := make(map[string]string)
@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
data["Stack"] = Stack data["Stack"] = Stack
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version() data["GoVersion"] = runtime.Version()
rw.WriteHeader(500)
t.Execute(rw, data) t.Execute(rw, data)
} }
@ -174,18 +176,19 @@ var errtpl = `
</html> </html>
` `
// map of http handlers for each error string.
var ErrorMaps map[string]http.HandlerFunc var ErrorMaps map[string]http.HandlerFunc
func init() { func init() {
ErrorMaps = make(map[string]http.HandlerFunc) ErrorMaps = make(map[string]http.HandlerFunc)
} }
//404 // show 404 notfound error.
func NotFound(rw http.ResponseWriter, r *http.Request) { func NotFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Page Not Found" data["Title"] = "Page Not Found"
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." + data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The page has moved" + "<br>The page has moved" +
@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//401 // show 401 unauthorized error.
func Unauthorized(rw http.ResponseWriter, r *http.Request) { func Unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Unauthorized" data["Title"] = "Unauthorized"
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." + data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Check the credentials that you supplied" + "<br>The credentials you supplied are incorrect" +
"<br>Check the address for errors" + "<br>There are errors in the website address" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusUnauthorized) //rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data) t.Execute(rw, data)
} }
//403 // show 403 forbidden error.
func Forbidden(rw http.ResponseWriter, r *http.Request) { func Forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Forbidden" data["Title"] = "Forbidden"
data["Content"] = template.HTML("<br>The Page You have requested forbidden." + data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Your address may be blocked" + "<br>Your address may be blocked" +
@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//503 // show 503 service unavailable error.
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Service Unavailable" data["Title"] = "Service Unavailable"
data["Content"] = template.HTML("<br>The Page You have requested unavailable." + data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br><br>The page is overloaded" + "<br><br>The page is overloaded" +
@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//500 // show 500 internal server error.
func InternalServerError(rw http.ResponseWriter, r *http.Request) { func InternalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Internal Server Error" data["Title"] = "Internal Server Error"
data["Content"] = template.HTML("<br>The Page You have requested has down now." + data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>simply try again later" + "<br>Please try again later and report the error to the website administrator" +
"<br>you should report the fault to the website administrator" + "<br></ul>")
"</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusInternalServerError) //rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 500 internal error with simple text string.
func SimpleServerError(rw http.ResponseWriter, r *http.Request) { func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} }
// add http handler for given error string.
func Errorhandler(err string, h http.HandlerFunc) { func Errorhandler(err string, h http.HandlerFunc) {
ErrorMaps[err] = h ErrorMaps[err] = h
} }
func RegisterErrorHander() { // register default error http handlers, 404,401,403,500 and 503.
func RegisterErrorHandler() {
if _, ok := ErrorMaps["404"]; !ok { if _, ok := ErrorMaps["404"]; !ok {
ErrorMaps["404"] = NotFound ErrorMaps["404"] = NotFound
} }
@ -292,6 +297,8 @@ func RegisterErrorHander() {
} }
} }
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) { func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
if h, ok := ErrorMaps[errcode]; ok { if h, ok := ErrorMaps[errcode]; ok {
isint, err := strconv.Atoi(errcode) isint, err := strconv.Atoi(errcode)

View File

@ -2,16 +2,19 @@ package middleware
import "fmt" import "fmt"
// http exceptions
type HTTPException struct { type HTTPException struct {
StatusCode int // http status code 4xx, 5xx StatusCode int // http status code 4xx, 5xx
Description string Description string
} }
// return http exception error string, e.g. "400 Bad Request".
func (e *HTTPException) Error() string { func (e *HTTPException) Error() string {
// return `status description`, e.g. `400 Bad Request`
return fmt.Sprintf("%d %s", e.StatusCode, e.Description) return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
} }
// map of http exceptions for each http status code int.
// defined 400,401,403,404,405,500,502,503 and 504 default.
var HTTPExceptionMaps map[int]HTTPException var HTTPExceptionMaps map[int]HTTPException
func init() { func init() {

View File

@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
".mustache": "text/html", ".mustache": "text/html",
} }
func initMime() { func initMime() error {
for k, v := range mimemaps { for k, v := range mimemaps {
mime.AddExtensionType(k, v) mime.AddExtensionType(k, v)
} }
return nil
} }

View File

@ -16,6 +16,7 @@ var (
commands = make(map[string]commander) commands = make(map[string]commander)
) )
// print help.
func printHelp(errs ...string) { func printHelp(errs ...string) {
content := `orm command usage: content := `orm command usage:
@ -31,6 +32,7 @@ func printHelp(errs ...string) {
os.Exit(2) os.Exit(2)
} }
// listen for orm command and then run it if command arguments passed.
func RunCommand() { func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" { if len(os.Args) < 2 || os.Args[1] != "orm" {
return return
@ -58,6 +60,7 @@ func RunCommand() {
} }
} }
// sync database struct command interface.
type commandSyncDb struct { type commandSyncDb struct {
al *alias al *alias
force bool force bool
@ -66,6 +69,7 @@ type commandSyncDb struct {
rtOnError bool rtOnError bool
} }
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) { func (d *commandSyncDb) Parse(args []string) {
var name string var name string
@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSyncDb) Run() error { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
return nil return nil
} }
// database creation commander interface implement.
type commandSqlAll struct { type commandSqlAll struct {
al *alias al *alias
} }
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) { func (d *commandSqlAll) Parse(args []string) {
var name string var name string
@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSqlAll) Run() error { func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSql(d.al)
var all []string var all []string
@ -243,6 +251,10 @@ func init() {
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSqlAll)
} }
// run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error { func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap() BootStrap()

View File

@ -12,6 +12,7 @@ type dbIndex struct {
Sql string Sql string
} }
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) { func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls return sqls
} }
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
fieldType := fi.fieldType fieldType := fi.fieldType
@ -79,6 +81,7 @@ checkColumn:
return return
} }
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string { func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi) typ := getColumnTyp(al, fi)
@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
} }
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")

177
orm/db.go
View File

@ -15,7 +15,7 @@ const (
) )
var ( var (
ErrMissPK = errors.New("missed pk value") ErrMissPK = errors.New("missed pk value") // missing pk error
) )
var ( var (
@ -45,13 +45,22 @@ var (
} }
) )
// an instance of dbBaser interface/
type dbBase struct { type dbBase struct {
ins dbBaser ins dbBaser
} }
// check dbBase implements dbBaser interface.
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
var columns []string
if names != nil {
columns = *names
}
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil { if fi, _ = mi.fields.GetByAny(column); fi != nil {
@ -64,14 +73,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
columns = append(columns, column)
if names != nil {
columns = append(columns, column)
}
values = append(values, value) values = append(values, value)
} }
if names != nil {
*names = columns
}
return return
} }
// get one field value in struct column as interface.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{} var value interface{}
if fi.pk { if fi.pk {
@ -140,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
return value, nil return value, nil
} }
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
@ -165,8 +185,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err return stmt, query, err
} }
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -185,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
} }
} }
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
@ -192,7 +214,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +225,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
} }
whereCols = append(whereCols, pkColumn) whereCols = []string{pkColumn}
args = append(args, pkValue) args = append(args, pkValue)
} }
@ -243,16 +266,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return nil return nil
} }
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) names := make([]string, 0, len(mi.fields.dbcols)-1)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, names, values) return d.InsertValue(q, mi, false, names, values)
} }
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { // multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
values []interface{}
names []string
)
// typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len()
for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
} else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return cnt, err
}
if len(vus) != len(names) {
return cnt, ErrArgs
}
nums += copy(values[nums:], vus)
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
cnt += num
nums = 0
}
}
return cnt, nil
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -264,36 +348,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
qmarks := strings.Join(marks, ", ") qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep) columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := q.Exec(query, values...); err == nil { if res, err := q.Exec(query, values...); err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId() return res.LastInsertId()
} else { } else {
return 0, err return 0, err
} }
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
} }
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
var setNames []string
// if specify cols length is zero, then commit all columns. // if specify cols length is zero, then commit all columns.
if len(cols) == 0 { if len(cols) == 0 {
cols = mi.fields.dbcols cols = mi.fields.dbcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
} else {
setNames = make([]string, 0, len(cols))
} }
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -317,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil return 0, nil
} }
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
@ -358,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil return 0, nil
} }
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
@ -433,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
@ -459,8 +564,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
return nil return nil
} }
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
@ -486,6 +594,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r rs = r
} }
defer rs.Close()
var ref interface{} var ref interface{}
args = make([]interface{}, 0) args = make([]interface{}, 0)
@ -532,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
@ -640,6 +751,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
slice := ind slice := ind
var cnt int64 var cnt int64
@ -739,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil return cnt, nil
} }
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -759,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return return
} }
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -812,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params return sql, params
} }
// gernerate sql string with inner function, such as UPPER(text).
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use // default not use
} }
// set values to struct column.
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols { for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -837,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
} }
} }
// convert value from database result to value following in field type.
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil { if val == nil {
return nil, nil return nil, nil
@ -989,6 +1107,7 @@ end:
} }
// set one value to struct column field.
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.fieldType
@ -1063,6 +1182,7 @@ setValue:
return value, nil return value, nil
} }
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
@ -1150,6 +1270,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
var ( var (
cnt int64 cnt int64
columns []string columns []string
@ -1228,6 +1350,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil return cnt, nil
} }
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
return 0, nil
}
// flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool { func (d *dbBase) SupportUpdateJoin() bool {
return true return true
} }
@ -1236,30 +1363,37 @@ func (d *dbBase) MaxLimit() uint64 {
return 18446744073709551615 return 18446744073709551615
} }
// return quote.
func (d *dbBase) TableQuote() string { func (d *dbBase) TableQuote() string {
return "`" return "`"
} }
// replace value placeholer in parametered sql string.
func (d *dbBase) ReplaceMarks(query *string) { func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing // default use `?` as mark, do nothing
} }
// flag of RETURNING sql.
func (d *dbBase) HasReturningID(*modelInfo, *string) bool { func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// convert time to db.
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// get database types.
func (d *dbBase) DbTypes() map[string]string { func (d *dbBase) DbTypes() map[string]string {
return nil return nil
} }
// gt all tables.
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool) tables := make(map[string]bool)
query := d.ins.ShowTablesQuery() query := d.ins.ShowTablesQuery()
@ -1268,6 +1402,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, err return tables, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var table string var table string
err := rows.Scan(&table) err := rows.Scan(&table)
@ -1282,6 +1418,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, nil return tables, nil
} }
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string) columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
@ -1290,6 +1427,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, err return columns, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var ( var (
name string name string
@ -1306,18 +1445,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, nil return columns, nil
} }
// not implement.
func (d *dbBase) OperatorSql(operator string) string { func (d *dbBase) OperatorSql(operator string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowTablesQuery() string { func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowColumnsQuery(table string) string { func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -3,33 +3,37 @@ package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"reflect" "reflect"
"sync" "sync"
"time" "time"
) )
// database driver constant int.
type DriverType int type DriverType int
const ( const (
_ DriverType = iota _ DriverType = iota // int enum type
DR_MySQL DR_MySQL // mysql
DR_Sqlite DR_Sqlite // sqlite
DR_Oracle DR_Oracle // oracle
DR_Postgres DR_Postgres // pgsql
) )
// database driver string.
type driver string type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType { func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d)) a, _ := dataBaseCache.get(string(d))
return a.Driver return a.Driver
} }
// get name of current driver
func (d driver) Name() string { func (d driver) Name() string {
return string(d) return string(d)
} }
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver) var _ Driver = new(driver)
var ( var (
@ -47,11 +51,13 @@ var (
} }
) )
// database alias cacher.
type _dbCache struct { type _dbCache struct {
mux sync.RWMutex mux sync.RWMutex
cache map[string]*alias cache map[string]*alias
} }
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) { func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock() ac.mux.Lock()
defer ac.mux.Unlock() defer ac.mux.Unlock()
@ -62,6 +68,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
return return
} }
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) { func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock() ac.mux.RLock()
defer ac.mux.RUnlock() defer ac.mux.RUnlock()
@ -69,6 +76,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
return return
} }
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) { func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default") al, _ = ac.get("default")
return return
@ -87,57 +95,29 @@ type alias struct {
Engine string Engine string
} }
// Setting the database connect params. Use the database driver self dataSource args. func detectTZ(al *alias) {
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DataSource = dataSource
var (
err error
)
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
}
if dataBaseCache.add(aliasName, al) == false {
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
goto end
}
al.DB, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
// orm timezone system match database // orm timezone system match database
// default use Local // default use Local
al.TZ = time.Local al.TZ = time.Local
if al.DriverName == "sphinx" {
return
}
switch al.Driver { switch al.Driver {
case DR_MySQL: case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone") row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
if tz == "SYSTEM" { if len(tz) >= 8 {
tz = "" if tz[0] != '-' {
row = al.DB.QueryRow("SELECT @@system_time_zone") tz = "+" + tz
row.Scan(&tz)
t, err := time.Parse("MST", tz)
if err == nil {
al.TZ = t.Location()
} }
} else { t, err := time.Parse("-07:00:00", tz)
t, err := time.Parse("-07:00", tz)
if err == nil { if err == nil {
al.TZ = t.Location() al.TZ = t.Location()
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
@ -163,8 +143,64 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
loc, err := time.LoadLocation(tz) loc, err := time.LoadLocation(tz)
if err == nil { if err == nil {
al.TZ = loc al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = db
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
if dataBaseCache.add(aliasName, al) == false {
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params { for i, v := range params {
switch i { switch i {
@ -175,39 +211,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
} }
} }
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
end: end:
if err != nil { if err != nil {
fmt.Println(err.Error()) if db != nil {
os.Exit(2) db.Close()
}
DebugLog.Println(err.Error())
} }
return err
} }
// Register a database driver use specify driver name, this can be definition the driver is which database type. // Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) { func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false { if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ drivers[driverName] = typ
} else { } else {
if t != typ { if t != typ {
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName) return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
os.Exit(2)
} }
} }
return nil
} }
// Change the database default used timezone // Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) { func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
} else { } else {
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName) return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
os.Exit(2)
} }
return nil
} }
// Change the max idle conns for *sql.DB, use specify database alias name // Change the max idle conns for *sql.DB, use specify database alias name

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
) )
// mysql operators.
var mysqlOperators = map[string]string{ var mysqlOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ?", "iexact": "LIKE ?",
@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?", "iendswith": "LIKE ?",
} }
// mysql column field types.
var mysqlTypes = map[string]string{ var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// mysql dbBaser implementation.
type dbBaseMysql struct { type dbBaseMysql struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string { func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string { func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes return mysqlTypes
} }
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string { func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
} }
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string { func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table) "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
} }
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0 return cnt > 0
} }
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser { func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql) b := new(dbBaseMysql)
b.ins = b b.ins = b

View File

@ -1,11 +1,13 @@
package orm package orm
// oracle dbBaser
type dbBaseOracle struct { type dbBaseOracle struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseOracle) var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser { func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle) b := new(dbBaseOracle)
b.ins = b b.ins = b

View File

@ -5,6 +5,7 @@ import (
"strconv" "strconv"
) )
// postgresql operators.
var postgresOperators = map[string]string{ var postgresOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "= UPPER(?)", "iexact": "= UPPER(?)",
@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)", "iendswith": "LIKE UPPER(?)",
} }
// postgresql column field types.
var postgresTypes = map[string]string{ var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY", "auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// postgresql dbBaser.
type dbBasePostgres struct { type dbBasePostgres struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBasePostgres) var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string { func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator { switch operator {
case "contains", "startswith", "endswith": case "contains", "startswith", "endswith":
@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
} }
} }
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool { func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false return false
} }
@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0 return 0
} }
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string { func (d *dbBasePostgres) TableQuote() string {
return `"` return `"`
} }
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) { func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query q := *query
num := 0 num := 0
@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data) *query = string(data)
} }
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if query != nil { if query != nil {
@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return return
} }
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string { func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
} }
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string { func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
} }
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string { func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes return postgresTypes
} }
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query) row := db.QueryRow(query)
@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0 return cnt > 0
} }
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)
b.ins = b b.ins = b

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// sqlite operators.
var sqliteOperators = map[string]string{ var sqliteOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'", "iexact": "LIKE ? ESCAPE '\\'",
@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'", "iendswith": "LIKE ? ESCAPE '\\'",
} }
// sqlite column types.
var sqliteTypes = map[string]string{ var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
"float64-decimal": "decimal", "float64-decimal": "decimal",
} }
// sqlite dbBaser.
type dbBaseSqlite struct { type dbBaseSqlite struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string { func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol) *leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
} }
} }
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool { func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false return false
} }
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 { func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807 return 9223372036854775807
} }
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string { func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes return sqliteTypes
} }
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string { func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'" return "SELECT name FROM sqlite_master WHERE type = 'table'"
} }
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
return columns, nil return columns, nil
} }
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table) return fmt.Sprintf("pragma table_info('%s')", table)
} }
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table) query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false return false
} }
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)
b.ins = b b.ins = b

View File

@ -6,6 +6,7 @@ import (
"time" "time"
) )
// table info struct.
type dbTable struct { type dbTable struct {
id int id int
index string index string
@ -18,13 +19,17 @@ type dbTable struct {
jtl *dbTable jtl *dbTable
} }
// tables collection struct, contains some tables.
type dbTables struct { type dbTables struct {
tablesM map[string]*dbTable tablesM map[string]*dbTable
tables []*dbTable tables []*dbTable
mi *modelInfo mi *modelInfo
base dbBaser base dbBaser
skipEnd bool
} }
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok { if j, ok := t.tablesM[name]; ok {
@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name] return t.tablesM[name]
} }
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false { if _, ok := t.tablesM[name]; ok == false {
@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name], false return t.tablesM[name], false
} }
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) { func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name] j, ok := t.tablesM[name]
return j, ok return j, ok
} }
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany { if depth < 0 || fi.fieldType == RelManyToMany {
return related return related
@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
return related return related
} }
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) { func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels) relsNum := len(rels)
@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
names = append(names, fi.name) names = append(names, fi.name)
mmi = fi.relModelInfo mmi = fi.relModelInfo
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
} }
// generate join string.
func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote() Q := t.base.TableQuote()
@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
return return
} }
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var ( var (
jtl *dbTable jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi mmi = mi
) )
@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
inner := true inner := true
loopFor:
for i, ex := range exprs { for i, ex := range exprs {
fi, ok := mmi.fields.GetByAny(ex) var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok { if ok {
@ -216,44 +243,61 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
mmi = fi.reverseFieldInfo.mi mmi = fi.reverseFieldInfo.mi
} }
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (fi.mi.isThrough == false || num != i) { if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
jt, _ := t.add(names, mmi, fi, inner) if t.skipEnd && okN || !t.skipEnd {
jt.jtl = jtl if t.skipEnd && okN && fiN.pk {
jtl = jt goto loopEnd
}
if num == i {
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
} }
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
}
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
} }
} }
break loopFor
} else { } else {
index = "" index = ""
name = "" name = ""
@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return return
} }
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return return
} }
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) { func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 { if len(orders) == 0 {
return return
@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
return return
} }
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
return return
} }
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable) tables.tablesM = make(map[string]*dbTable)

View File

@ -6,6 +6,7 @@ import (
"time" "time"
) )
// get table alias.
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
return nil return nil
} }
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk fi := mi.fields.pk
@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return return
} }
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor: outFor:

View File

@ -41,6 +41,7 @@ var (
} }
) )
// model info collection
type _modelCache struct { type _modelCache struct {
sync.RWMutex sync.RWMutex
orders []string orders []string
@ -49,6 +50,7 @@ type _modelCache struct {
done bool done bool
} }
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache)) m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache { for k, v := range mc.cache {
@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m return m
} }
// get orderd model info
func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {
@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
return m return m
} }
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table] mi, ok = mc.cache[table]
return return
} }
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name] mi, ok = mc.cacheByFN[name]
return return
} }
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table] mii := mc.cache[table]
mc.cache[table] = mi mc.cache[table] = mi
@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
return mii return mii
} }
// clean all model info.
func (mc *_modelCache) clean() { func (mc *_modelCache) clean() {
mc.orders = make([]string, 0) mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo) mc.cache = make(map[string]*modelInfo)

View File

@ -8,6 +8,8 @@ import (
"strings" "strings"
) )
// register models.
// prefix means table name prefix.
func registerModel(model interface{}, prefix string) { func registerModel(model interface{}, prefix string) {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
modelCache.set(table, info) modelCache.set(table, info)
} }
// boostrap models
func bootStrap() { func bootStrap() {
if modelCache.done { if modelCache.done {
return return
@ -281,6 +284,7 @@ end:
} }
} }
// register models
func RegisterModel(models ...interface{}) { func RegisterModel(models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap")) panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
} }
} }
// bootrap models.
// make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {
return return

View File

@ -9,6 +9,7 @@ import (
var errSkipField = errors.New("skip field") var errSkipField = errors.New("skip field")
// field info collection
type fields struct { type fields struct {
pk *fieldInfo pk *fieldInfo
columns map[string]*fieldInfo columns map[string]*fieldInfo
@ -23,6 +24,7 @@ type fields struct {
dbcols []string dbcols []string
} }
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) { func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil { if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi f.columns[fi.column] = fi
@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
return true return true
} }
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo { func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name] return f.fields[name]
} }
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo { func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column] return f.columns[column]
} }
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) { func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok { if fi, ok := f.fields[name]; ok {
return fi, ok return fi, ok
@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
return nil, false return nil, false
} }
// create new field info collection
func newFields() *fields { func newFields() *fields {
f := new(fields) f := new(fields)
f.fields = make(map[string]*fieldInfo) f.fields = make(map[string]*fieldInfo)
@ -79,6 +85,7 @@ func newFields() *fields {
return f return f
} }
// single field info
type fieldInfo struct { type fieldInfo struct {
mi *modelInfo mi *modelInfo
fieldIndex int fieldIndex int
@ -115,6 +122,7 @@ type fieldInfo struct {
onDelete string onDelete string
} }
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var ( var (
tag string tag string

View File

@ -7,6 +7,7 @@ import (
"reflect" "reflect"
) )
// single model info
type modelInfo struct { type modelInfo struct {
pkg string pkg string
name string name string
@ -20,6 +21,7 @@ type modelInfo struct {
isThrough bool isThrough bool
} }
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) { func newModelInfo(val reflect.Value) (info *modelInfo) {
var ( var (
err error err error
@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
return return
} }
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo) info = new(modelInfo)
info.fields = newFields() info.fields = newFields()

View File

@ -7,10 +7,12 @@ import (
"time" "time"
) )
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string { func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name() return typ.PkgPath() + "." + typ.Name()
} }
// get table name. method, or field name. auto snaked.
func getTableName(val reflect.Value) string { func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
fun := val.MethodByName("TableName") fun := val.MethodByName("TableName")
@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
return snakeString(ind.Type().Name()) return snakeString(ind.Type().Name())
} }
// get table engine, mysiam or innodb.
func getTableEngine(val reflect.Value) string { func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine") fun := val.MethodByName("TableEngine")
if fun.IsValid() { if fun.IsValid() {
@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
return "" return ""
} }
// get table index from method.
func getTableIndex(val reflect.Value) [][]string { func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex") fun := val.MethodByName("TableIndex")
if fun.IsValid() { if fun.IsValid() {
@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
return nil return nil
} }
// get table unique from method
func getTableUnique(val reflect.Value) [][]string { func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique") fun := val.MethodByName("TableUnique")
if fun.IsValid() { if fun.IsValid() {
@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
return nil return nil
} }
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col) col = strings.ToLower(col)
column := col column := col
@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
return column return column
} }
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return return
} }
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool) attr := make(map[string]bool)
tag := make(map[string]string) tag := make(map[string]string)

View File

@ -25,6 +25,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows") ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found") ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
@ -39,11 +40,12 @@ type orm struct {
var _ Ormer = new(orm) var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { // get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind = reflect.Indirect(val) ind = reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
if val.Kind() != reflect.Ptr { if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
name := getFullName(typ) name := getFullName(typ)
@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
} }
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name) fi, ok := mi.fields.GetByAny(name)
if !ok { if !ok {
@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
return fi return fi
} }
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
@ -70,26 +74,83 @@ func (o *orm) Read(md interface{}, cols ...string) error {
return nil return nil
} }
// Try to read a row from the database, or insert one if it doesn't exist
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err == ErrNoRows {
// Create
id, err := o.Insert(md)
return (err == nil), id, err
}
return false, ind.Field(mi.fields.pk.fieldIndex).Int(), err
}
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
if id > 0 {
if mi.fields.pk.auto { o.setPk(mi, ind, id)
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
return id, nil return id, nil
} }
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) { func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return num, err return num, err
@ -97,26 +158,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
return num, nil return num, nil
} }
// delete model in database
func (o *orm) Delete(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { o.setPk(mi, ind, 0)
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
} }
return num, nil return num, nil
} }
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
switch { switch {
@ -129,6 +186,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
return newQueryM2M(md, o, mi, fi, ind) return newQueryM2M(md, o, mi, fi, ind)
} }
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name) _, fi, ind, qseter := o.queryRelated(md, name)
@ -190,14 +255,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
return nums, err return nums, err
} }
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ? // is this api needed ?
_, _, _, qs := o.queryRelated(md, name) _, _, _, qs := o.queryRelated(md, name)
return qs return qs
} }
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind) _, _, exist := getExistPk(mi, ind)
@ -227,6 +299,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
return mi, fi, ind, qs return mi, fi, ind, qs
} }
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelReverseOne, RelReverseMany: case RelReverseOne, RelReverseMany:
@ -247,6 +320,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
return q return q
} }
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany: case RelOneToOne, RelForeignKey, RelManyToMany:
@ -266,6 +340,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
return q return q
} }
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
@ -285,6 +362,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return return
} }
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error { func (o *orm) Using(name string) error {
if o.isTx { if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db")) panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
@ -302,6 +380,7 @@ func (o *orm) Using(name string) error {
return nil return nil
} }
// begin transaction
func (o *orm) Begin() error { func (o *orm) Begin() error {
if o.isTx { if o.isTx {
return ErrTxHasBegan return ErrTxHasBegan
@ -320,6 +399,7 @@ func (o *orm) Begin() error {
return nil return nil
} }
// commit transaction
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -334,6 +414,7 @@ func (o *orm) Commit() error {
return err return err
} }
// rollback transaction
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -348,14 +429,23 @@ func (o *orm) Rollback() error {
return err return err
} }
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter { func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args) return newRawSet(o, query, args)
} }
// return current using database Driver
func (o *orm) Driver() Driver { func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
func (o *orm) GetDB() dbQuerier {
panic(ErrNotImplement)
// not enough
return o.db
}
// create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
@ -366,3 +456,30 @@ func NewOrm() Ormer {
} }
return o return o
} }
// create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
if dr, ok := drivers[driverName]; ok {
al = new(alias)
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
al.Name = aliasName
al.DriverName = driverName
o := new(orm)
o.alias = al
if Debug {
o.db = newDbQueryLog(o.alias, db)
} else {
o.db = db
}
return o, nil
}

View File

@ -18,15 +18,19 @@ type condValue struct {
isCond bool isCond bool
} }
// condition struct.
// work for WHERE conditions.
type Condition struct { type Condition struct {
params []condValue params []condValue
} }
// return new condition struct
func NewCondition() *Condition { func NewCondition() *Condition {
c := &Condition{} c := &Condition{}
return c return c
} }
// add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty")) panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
// add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty")) panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c return c
} }
// check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool { func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
// clone a condition
func (c Condition) clone() *Condition { func (c Condition) clone() *Condition {
return &c return &c
} }

View File

@ -13,6 +13,7 @@ type Log struct {
*log.Logger *log.Logger
} }
// set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", 1e9)
@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
DebugLog.Println(con) DebugLog.Println(con)
} }
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct { type stmtQueryLog struct {
alias *alias alias *alias
query string query string
@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
return d return d
} }
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct { type dbQueryLog struct {
alias *alias alias *alias
db dbQuerier db dbQuerier

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
) )
// an insert queryer struct
type insertSet struct { type insertSet struct {
mi *modelInfo mi *modelInfo
orm *orm orm *orm
@ -14,6 +15,7 @@ type insertSet struct {
var _ Inserter = new(insertSet) var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) { func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
return id, nil return id, nil
} }
// close insert queryer statement
func (o *insertSet) Close() error { func (o *insertSet) Close() error {
if o.closed { if o.closed {
return ErrStmtClosed return ErrStmtClosed
@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
return o.stmt.Close() return o.stmt.Close()
} }
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet) bi := new(insertSet)
bi.orm = orm bi.orm = orm

View File

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
) )
// model to model struct
type queryM2M struct { type queryM2M struct {
md interface{} md interface{}
mi *modelInfo mi *modelInfo
@ -12,6 +13,13 @@ type queryM2M struct {
ind reflect.Value ind reflect.Value
} }
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) { func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
mi := fi.relThroughModelInfo mi := fi.relThroughModelInfo
@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column} names := []string{mfi.column, rfi.column}
var nums int64 values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
} }
values := []interface{}{v1, v2} values = append(values, v1, v2)
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
}
nums += 1
} }
return nums, nil return dbase.InsertValue(orm.db, mi, true, names, values)
} }
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return nums, nil return nums, nil
} }
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool { func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md). return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist() Filter(fi.reverseFieldInfoTwo.name, md).Exist()
} }
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) { func (o *queryM2M) Clear() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
} }
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) { func (o *queryM2M) Count() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
var _ QueryM2Mer = new(queryM2M) var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M) qm2m := new(queryM2M)
qm2m.md = md qm2m.md = md

View File

@ -18,6 +18,10 @@ const (
Col_Except Col_Except
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except: case Col_Add, Col_Minus, Col_Multiply, Col_Except:
@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
return val return val
} }
// real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
@ -47,6 +52,7 @@ type querySet struct {
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// set offset number
func (o *querySet) setOffset(num interface{}) { func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num) o.offset = ToInt64(num)
} }
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit) o.limit = ToInt64(limit)
if len(args) > 0 { if len(args) > 0 {
@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
return &o return &o
} }
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter { func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset) o.setOffset(offset)
return &o return &o
} }
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs o.orders = exprs
return &o return &o
} }
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string var related []string
if len(params) == 0 { if len(params) == 0 {
@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
return &o return &o
} }
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter { func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond o.cond = cond
return &o return &o
} }
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool { func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0 return cnt > 0
} }
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) { func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
} }
// execute delete
func (o *querySet) Delete() (int64, error) { func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) { func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi) return newInsertSet(o.orm, o.mi)
} }
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) { func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
} }
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil { if err != nil {
@ -149,18 +178,56 @@ func (o *querySet) One(container interface{}, cols ...string) error {
return nil return nil
} }
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
} }
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, result, keyCol, valueCol, o.orm.alias.TZ)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, ptrStruct, keyCol, valueCol, o.orm.alias.TZ)
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
o.mi = mi o.mi = mi

View File

@ -4,10 +4,10 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
// raw sql string prepared statement
type rawPrepare struct { type rawPrepare struct {
rs *rawSet rs *rawSet
stmt stmtQuerier stmt stmtQuerier
@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
return o, nil return o, nil
} }
// raw query seter
type rawSet struct { type rawSet struct {
query string query string
args []interface{} args []interface{}
@ -53,11 +54,13 @@ type rawSet struct {
var _ RawSeter = new(rawSet) var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter { func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args o.args = args
return &o return &o
} }
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) { func (o *rawSet) Exec() (sql.Result, error) {
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(query, args...) return o.orm.db.Exec(query, args...)
} }
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() { switch ind.Kind() {
case reflect.Bool: case reflect.Bool:
@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} }
} }
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) { // set field value in loop for slice container
sIdxes := *sIdxesPtr func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr nInds := *nIndsPtr
cur := 0 cur := 0
for i, idxs := range sIdxes { for i := 0; i < len(sInds); i++ {
sInd := sInds[i] sInd := sInds[i]
eTyp := eTyps[i] eTyp := eTyps[i]
@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value) o.setFieldValue(ind, value)
} }
cur++ cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
} }
} else { } else {
value := reflect.ValueOf(refs[cur]).Elem().Interface() value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil { if isPtr && value == nil {
@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
} }
} }
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error { func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers)) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
typ = typ.Elem() typ = typ.Elem()
} }
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind) sInds = append(sInds, ind)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...) rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err := row.Scan(refs...); err == sql.ErrNoRows { if err == sql.ErrNoRows {
return ErrNoRows return ErrNoRows
} else if err != nil { }
return err return err
} }
nInds := make([]reflect.Value, len(sInds)) defer rows.Close()
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
for i, sInd := range sInds { if rows.Next() {
nInd := nInds[i] if structMode {
sInd.Set(nInd) columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
} else {
return ErrNoRows
} }
return nil return nil
} }
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
sInd := reflect.Indirect(val) sInd := reflect.Indirect(val)
@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd) sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
@ -401,30 +418,107 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return 0, err return 0, err
} }
nInds := make([]reflect.Value, len(sInds)) defer rows.Close()
var cnt int64 var cnt int64
for rows.Next() { nInds := make([]reflect.Value, len(sInds))
if err := rows.Scan(refs...); err != nil { sInd := sInds[0]
return 0, err
}
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0) for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++ cnt++
} }
if cnt > 0 { if cnt > 0 {
for i, sInd := range sInds {
nInd := nInds[i] if structMode {
sInd.Set(nInd) sInds[0].Set(sInd)
} else {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
} }
} }
return cnt, nil return cnt, nil
} }
func (o *rawSet) readValues(container interface{}) (int64, error) { func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
var ( var (
maps []Params maps []Params
lists []ParamsList lists []ParamsList
@ -455,21 +549,41 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
rs = r rs = r
} }
defer rs.Close()
var ( var (
refs []interface{} refs []interface{}
cnt int64 cnt int64
cols []string cols []string
indexs []int
) )
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { if columns, err := rs.Columns(); err != nil {
return 0, err return 0, err
} else { } else {
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
indexs = make([]int, 0, len(columns))
}
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i, _ := range refs { for i, _ := range refs {
var ref sql.NullString var ref sql.NullString
refs[i] = &ref refs[i] = &ref
if len(needCols) > 0 {
for _, c := range needCols {
if c == cols[i] {
indexs = append(indexs, i)
}
}
} else {
indexs = append(indexs, i)
}
} }
} }
} }
@ -481,7 +595,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
switch typ { switch typ {
case 1: case 1:
params := make(Params, len(cols)) params := make(Params, len(cols))
for i, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
params[cols[i]] = value.String params[cols[i]] = value.String
@ -492,7 +607,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
maps = append(maps, params) maps = append(maps, params)
case 2: case 2:
params := make(ParamsList, 0, len(cols)) params := make(ParamsList, 0, len(cols))
for _, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
params = append(params, value.String) params = append(params, value.String)
@ -502,7 +618,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
} }
lists = append(lists, params) lists = append(lists, params)
case 3: case 3:
for _, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
list = append(list, value.String) list = append(list, value.String)
@ -527,18 +644,166 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil return cnt, nil
} }
func (o *rawSet) Values(container *[]Params) (int64, error) { func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
return o.readValues(container) var (
maps Params
ind *reflect.Value
)
typ := 0
switch container.(type) {
case *Params:
typ = 1
default:
typ = 2
vl := reflect.ValueOf(container)
id := reflect.Indirect(vl)
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
}
ind = &id
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
)
var (
keyIndex = -1
valueIndex = -1
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 {
switch typ {
case 1:
maps = make(Params)
}
}
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
switch typ {
case 1:
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
if value.Valid {
maps[key] = value.String
} else {
maps[key] = nil
}
default:
if id := ind.FieldByName(camelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
}
}
cnt++
}
if typ == 1 {
v, _ := container.(*Params)
*v = maps
}
return cnt, nil
} }
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { // query data to []map[string]interface
return o.readValues(container) func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
return o.readValues(container, cols)
} }
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { // query data to [][]interface
return o.readValues(container) func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
} }
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(result, keyCol, valueCol)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
}
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) { func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o) return newRawPreparer(o)
} }

View File

@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
} }
} }
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var ( var (
uid int uid int
status *int status *int
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) { func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q) query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) num, err := dORM.Raw(query).QueryRows(&datas)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1)) throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0])) ind := reflect.Indirect(reflect.ValueOf(datas[0]))
@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
type Tmp struct { var datas2 []Data
Id int
Name string query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
Skiped0 string `orm:"-"` num, err = dORM.Raw(query).QueryRows(&datas2)
Pid *int throwFailNow(t, err)
Skiped1 Data throwFailNow(t, AssertIs(num, 1))
Skiped2 *Data throwFailNow(t, AssertIs(len(datas2), 1))
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
} }
var ( var ids []int
ids []int var usernames []string
userNames []string query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
profileIds1 []int num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(len(ids), 3))
var users []User throwFailNow(t, AssertIs(ids[0], 2))
dORM.QueryTable("user").OrderBy("Id").All(&users) throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
for i := 0; i < 3; i++ { throwFailNow(t, AssertIs(usernames[1], "astaxie"))
id := ids[i] throwFailNow(t, AssertIs(ids[2], 4))
name := userNames[i] throwFailNow(t, AssertIs(usernames[2], "nobody"))
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
}
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 6))
qs = dORM.QueryTable("post")
num, err = qs.Filter("Id", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
fmt.Println("...")
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
@ -1724,3 +1642,41 @@ func TestTransaction(t *testing.T) {
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
} }
func TestReadOrCreate(t *testing.T) {
u := &User{
UserName: "Kyle",
Email: "kylemcc@gmail.com",
Password: "other_pass",
Status: 7,
IsStaff: false,
IsActive: true,
}
created, pk, err := dORM.ReadOrCreate(u, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.UserName, "Kyle"))
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
throwFail(t, AssertIs(u.Password, "other_pass"))
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id))
throwFail(t, AssertIs(pk, u.Id))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))
throwFail(t, AssertIs(nu.Status, u.Status))
throwFail(t, AssertIs(nu.IsStaff, u.IsStaff))
throwFail(t, AssertIs(nu.IsActive, u.IsActive))
dORM.Delete(u)
}

View File

@ -6,11 +6,13 @@ import (
"time" "time"
) )
// database driver
type Driver interface { type Driver interface {
Name() string Name() string
Type() DriverType Type() DriverType
} }
// field info
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -18,9 +20,12 @@ type Fielder interface {
RawValue() interface{} RawValue() interface{}
} }
// orm struct
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error Read(interface{}, ...string) error
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error) Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error)
@ -32,13 +37,16 @@ type Ormer interface {
Rollback() error Rollback() error
Raw(string, ...interface{}) RawSeter Raw(string, ...interface{}) RawSeter
Driver() Driver Driver() Driver
GetDB() dbQuerier
} }
// insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
// query seter
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
@ -57,8 +65,11 @@ type QuerySeter interface {
Values(*[]Params, ...string) (int64, error) Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
} }
// model to model query struct
type QueryM2Mer interface { type QueryM2Mer interface {
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
@ -67,22 +78,27 @@ type QueryM2Mer interface {
Count() (int64, error) Count() (int64, error)
} }
// raw query statement
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (sql.Result, error) Exec(...interface{}) (sql.Result, error)
Close() error Close() error
} }
// raw query seter
type RawSeter interface { type RawSeter interface {
Exec() (sql.Result, error) Exec() (sql.Result, error)
QueryRow(...interface{}) error QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error) QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error) Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList) (int64, error) ValuesFlat(*ParamsList, ...string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
Prepare() (RawPreparer, error) Prepare() (RawPreparer, error)
} }
// statement querier
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
@ -90,6 +106,7 @@ type stmtQuerier interface {
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
} }
// db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
@ -97,19 +114,31 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
} }
// type DB interface {
// Begin() (*sql.Tx, error)
// Prepare(query string) (stmtQuerier, error)
// Exec(query string, args ...interface{}) (sql.Result, error)
// Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryRow(query string, args ...interface{}) *sql.Row
// }
// transaction beginner
type txer interface { type txer interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
} }
// transaction ending
type txEnder interface { type txEnder interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
@ -123,6 +152,7 @@ type dbBaser interface {
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)

View File

@ -10,6 +10,7 @@ import (
type StrTo string type StrTo string
// set string
func (f *StrTo) Set(v string) { func (f *StrTo) Set(v string) {
if v != "" { if v != "" {
*f = StrTo(v) *f = StrTo(v)
@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
} }
} }
// clean string
func (f *StrTo) Clear() { func (f *StrTo) Clear() {
*f = StrTo(0x1E) *f = StrTo(0x1E)
} }
// check string exist
func (f StrTo) Exist() bool { func (f StrTo) Exist() bool {
return string(f) != string(0x1E) return string(f) != string(0x1E)
} }
// string to bool
func (f StrTo) Bool() (bool, error) { func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String()) return strconv.ParseBool(f.String())
} }
// string to float32
func (f StrTo) Float32() (float32, error) { func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32) v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err return float32(v), err
} }
// string to float64
func (f StrTo) Float64() (float64, error) { func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
// string to int
func (f StrTo) Int() (int, error) { func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err return int(v), err
} }
// string to int8
func (f StrTo) Int8() (int8, error) { func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8) v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err return int8(v), err
} }
// string to int16
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
} }
// string to int32
func (f StrTo) Int32() (int32, error) { func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err return int32(v), err
} }
// string to int64
func (f StrTo) Int64() (int64, error) { func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64) v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err return int64(v), err
} }
// string to uint
func (f StrTo) Uint() (uint, error) { func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err return uint(v), err
} }
// string to uint8
func (f StrTo) Uint8() (uint8, error) { func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8) v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err return uint8(v), err
} }
// string to uint16
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
} }
// string to uint31
func (f StrTo) Uint32() (uint32, error) { func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err return uint32(v), err
} }
// string to uint64
func (f StrTo) Uint64() (uint64, error) { func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64) v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err return uint64(v), err
} }
// string to string
func (f StrTo) String() string { func (f StrTo) String() string {
if f.Exist() { if f.Exist() {
return string(f) return string(f)
@ -96,6 +113,7 @@ func (f StrTo) String() string {
return "" return ""
} }
// interface to string
func ToStr(value interface{}, args ...int) (s string) { func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) { switch v := value.(type) {
case bool: case bool:
@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s return s
} }
// interface to int64
func ToInt64(value interface{}) (d int64) { func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value) val := reflect.ValueOf(value)
switch value.(type) { switch value.(type) {
@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
return return
} }
// snake string, XxYy to xx_yy
func snakeString(s string) string { func snakeString(s string) string {
data := make([]byte, 0, len(s)*2) data := make([]byte, 0, len(s)*2)
j := false j := false
@ -164,6 +184,7 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:len(data)])) return strings.ToLower(string(data[:len(data)]))
} }
// camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))
j := false j := false
@ -190,6 +211,7 @@ func camelString(s string) string {
type argString []string type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) { func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
type argInt []int type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) { func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
type argAny []interface{} type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) { func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
return return
} }
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err
} }
// format time string
func timeFormat(t time.Time, format string) string { func timeFormat(t time.Time, format string) string {
return t.Format(format) return t.Format(format)
} }
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type { func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() { switch v.Kind() {
case reflect.Ptr: case reflect.Ptr:

75
plugins/auth/basic.go Normal file
View File

@ -0,0 +1,75 @@
// basic auth for plugin
package auth
// Example:
// func SecretAuth(username, password string) bool {
// if username == "astaxie" && password == "helloBeego" {
// return true
// }
// return false
// }
// authPlugin := auth.NewBasicAuthenticator(SecretAuth)
// beego.AddFilter("*","AfterStatic",authPlugin)
import (
"encoding/base64"
"net/http"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
)
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuth{Secrets: secrets, Realm: Realm}
if username := a.CheckAuth(ctx.Request); username == "" {
a.RequireAuth(ctx.ResponseWriter, ctx.Request)
}
}
}
type SecretProvider func(user, pass string) bool
type BasicAuth struct {
Secrets SecretProvider
Realm string
}
/*
Checks the username/password combination from the request. Returns
either an empty string (authentication failed) or the name of the
authenticated user.
Supports MD5 and SHA1 password entries
*/
func (a *BasicAuth) CheckAuth(r *http.Request) string {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 || s[0] != "Basic" {
return ""
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return ""
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return ""
}
if a.Secrets(pair[0], pair[1]) {
return pair[0]
}
return ""
}
/*
http.Handler for BasicAuth which initiates the authentication process
(or requires reauthentication).
*/
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
w.WriteHeader(401)
w.Write([]byte("401 Unauthorized\n"))
}

View File

@ -1,7 +1,10 @@
package beego package beego
import ( import (
"bufio"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -30,6 +33,14 @@ const (
var ( var (
// supported http methods. // supported http methods.
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"} HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
) )
type controllerInfo struct { type controllerInfo struct {
@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
params := make(map[int]string) params := make(map[int]string)
for i, part := range parts { for i, part := range parts {
if strings.HasPrefix(part, ":") { if strings.HasPrefix(part, ":") {
expr := "(.+)" expr := "(.*)"
//a user may choose to override the defult expression //a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+) // similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 { if index := strings.Index(part, "("); index != -1 {
@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
j++ j++
} }
if strings.HasPrefix(part, "*") { if strings.HasPrefix(part, "*") {
expr := "(.+)" expr := "(.*)"
if part == "*.*" { if part == "*.*" {
params[j] = ":path" params[j] = ":path"
parts[i] = "([^.]+).([^.]+)" parts[i] = "([^.]+).([^.]+)"
@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
// Add auto router to ControllerRegistor. // Add auto router to ControllerRegistor.
// example beego.AddAuto(&MainContorlller{}), // example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page. // MainController has method List and Page.
// visit the url /main/list to exec List function // visit the url /main/list to execute List function
// /main/page to exec Page function. // /main/page to execute Page function.
func (p *ControllerRegistor) AddAuto(c ControllerInterface) { func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.enableAuto = true p.enableAuto = true
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.autoRouter[firstParam] = make(map[string]reflect.Type) p.autoRouter[firstParam] = make(map[string]reflect.Type)
} }
for i := 0; i < rt.NumMethod(); i++ { for i := 0; i < rt.NumMethod(); i++ {
p.autoRouter[firstParam][rt.Method(i).Name] = ct if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// Add auto router to ControllerRegistor with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
if _, ok := p.autoRouter[firstParam]; ok {
return
} else {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
} }
} }
// [Deprecated] use InsertFilter. // [Deprecated] use InsertFilter.
// Add FilterFunc with pattern for action. // Add FilterFunc with pattern for action.
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) { func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
mr := buildFilter(pattern, filter) mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
switch action { switch action {
case "BeforeRouter": case "BeforeRouter":
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr) p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr) p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
} }
p.enableFilter = true p.enableFilter = true
return nil
} }
// Add a FilterFunc with pattern rule and action constant. // Add a FilterFunc with pattern rule and action constant.
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) { func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
mr := buildFilter(pattern, filter) mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
p.filters[pos] = append(p.filters[pos], mr) p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true p.enableFilter = true
return nil
} }
// UrlFor does another controller handler in this request function. // UrlFor does another controller handler in this request function.
@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init // session init
if SessionOn { if SessionOn {
context.Input.CruSession = GlobalSessions.SessionStart(w, r) context.Input.CruSession = GlobalSessions.SessionStart(w, r)
defer context.Input.CruSession.SessionRelease() defer func() {
context.Input.CruSession.SessionRelease(w)
}()
} }
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// pattern /admin url /admin 200 /admin/ 200 // pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200 // pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && len(route.pattern) == n+1 && if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
http.Redirect(w, r, requestPath+"/", 301) http.Redirect(w, r, requestPath+"/", 301)
goto Admin goto Admin
} }
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern { if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route) runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" { if runMethod != "" {
runrouter = route.controllerType runrouter = route.controllerType
@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
w.started = true w.started = true
w.writer.WriteHeader(code) w.writer.WriteHeader(code)
} }
// hijacker for http
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.writer.(http.Hijacker)
if !ok {
println("supported?")
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}

View File

@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
t.Errorf(w.Body.String() + "user define func can't run") t.Errorf(w.Body.String() + "user define func can't run")
} }
} }
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}

View File

@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider: * Use **memory** as provider:
func init() { func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **file** as provider, the last param is the path where you want file to be stored: * Use **file** as provider, the last param is the path where you want file to be stored:
func init() { func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() { func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC() go globalSessions.GC()
} }
@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() { func init() {
globalSessions, _ = session.NewManager( globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) { func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r) sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease() defer sess.SessionRelease(w)
username := sess.Get("username") username := sess.Get("username")
fmt.Println(username) fmt.Println(username)
if r.Method == "GET" { if r.Method == "GET" {
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition. (Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example. Maybe you will find the **memory** provider is a good example.
type SessionStore interface { type SessionStore interface {
Set(key, value interface{}) error //set session value Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)

170
session/sess_cookie.go Normal file
View File

@ -0,0 +1,170 @@
package session
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
// Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
lock sync.RWMutex
}
// Set value to cookie session.
// the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from cookie session
func (st *CookieSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
// Delete value in cookie session
func (st *CookieSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Clean all values in cookie session
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// Return id of this cookie session
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
// Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
cookiepder.config.SecurityName,
st.values)
if err != nil {
return
}
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(str),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure}
http.SetCookie(w, cookie)
return
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
// Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
// Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
// blockKey - gob encode hash string. it's saved as aes crypto.
// securityName - recognized name in encoded cookie string
// cookieName - cookie name
// maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
return nil
}
// Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
// Cookie session is always existed
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
// Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil, nil
}
// Implement method, no used.
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
// Implement method, no used.
func (pder *CookieProvider) SessionGC() {
return
}
// Implement method, return 0.
func (pder *CookieProvider) SessionAll() int {
return 0
}
// Implement method, no used.
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,38 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
globalSessions, err := NewManager("cookie", config)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -17,6 +18,7 @@ var (
gcmaxlifetime int64 gcmaxlifetime int64
) )
// File session store
type FileSessionStore struct { type FileSessionStore struct {
f *os.File f *os.File
sid string sid string
@ -24,6 +26,7 @@ type FileSessionStore struct {
values map[interface{}]interface{} values map[interface{}]interface{}
} }
// Set value to file session
func (fs *FileSessionStore) Set(key, value interface{}) error { func (fs *FileSessionStore) Set(key, value interface{}) error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
@ -31,6 +34,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error {
return nil return nil
} }
// Get value from file session
func (fs *FileSessionStore) Get(key interface{}) interface{} { func (fs *FileSessionStore) Get(key interface{}) interface{} {
fs.lock.RLock() fs.lock.RLock()
defer fs.lock.RUnlock() defer fs.lock.RUnlock()
@ -42,6 +46,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
return nil return nil
} }
// Delete value in file session by given key
func (fs *FileSessionStore) Delete(key interface{}) error { func (fs *FileSessionStore) Delete(key interface{}) error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
@ -49,6 +54,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
return nil return nil
} }
// Clean all values in file session
func (fs *FileSessionStore) Flush() error { func (fs *FileSessionStore) Flush() error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
@ -56,11 +62,13 @@ func (fs *FileSessionStore) Flush() error {
return nil return nil
} }
// Get file session store id
func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionID() string {
return fs.sid return fs.sid
} }
func (fs *FileSessionStore) SessionRelease() { // Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
defer fs.f.Close() defer fs.f.Close()
b, err := encodeGob(fs.values) b, err := encodeGob(fs.values)
if err != nil { if err != nil {
@ -71,17 +79,23 @@ func (fs *FileSessionStore) SessionRelease() {
fs.f.Write(b) fs.f.Write(b)
} }
// File session provider
type FileProvider struct { type FileProvider struct {
maxlifetime int64 maxlifetime int64
savePath string savePath string
} }
// Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime fp.maxlifetime = maxlifetime
fp.savePath = savePath fp.savePath = savePath
return nil return nil
} }
// Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) { func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil { if err != nil {
@ -116,6 +130,8 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
return ss, nil return ss, nil
} }
// Check file session exist.
// it checkes the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool { func (fp *FileProvider) SessionExist(sid string) bool {
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
if err == nil { if err == nil {
@ -125,16 +141,20 @@ func (fp *FileProvider) SessionExist(sid string) bool {
} }
} }
// Remove all files in this save path
func (fp *FileProvider) SessionDestroy(sid string) error { func (fp *FileProvider) SessionDestroy(sid string) error {
os.Remove(path.Join(fp.savePath)) os.Remove(path.Join(fp.savePath))
return nil return nil
} }
// Recycle files in save path
func (fp *FileProvider) SessionGC() { func (fp *FileProvider) SessionGC() {
gcmaxlifetime = fp.maxlifetime gcmaxlifetime = fp.maxlifetime
filepath.Walk(fp.savePath, gcpath) filepath.Walk(fp.savePath, gcpath)
} }
// Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll() int { func (fp *FileProvider) SessionAll() int {
a := &activeSession{} a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
@ -147,6 +167,8 @@ func (fp *FileProvider) SessionAll() int {
return a.total return a.total
} }
// Generate new sid for file session.
// it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil { if err != nil {
@ -196,6 +218,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, err
return ss, nil return ss, nil
} }
// remove file in save path if expired
func gcpath(path string, info os.FileInfo, err error) error { func gcpath(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err

View File

@ -1,38 +0,0 @@
package session
import (
"bytes"
"encoding/gob"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}

View File

@ -2,19 +2,23 @@ package session
import ( import (
"container/list" "container/list"
"net/http"
"sync" "sync"
"time" "time"
) )
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
// memory session store.
// it saved sessions in a map in memory.
type MemSessionStore struct { type MemSessionStore struct {
sid string //session id唯一标示 sid string //session id
timeAccessed time.Time //最后访问时间 timeAccessed time.Time //last access time
value map[interface{}]interface{} //session里面存储的值 value map[interface{}]interface{} //session store
lock sync.RWMutex lock sync.RWMutex
} }
// set value to memory session
func (st *MemSessionStore) Set(key, value interface{}) error { func (st *MemSessionStore) Set(key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -22,6 +26,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error {
return nil return nil
} }
// get value from memory session by key
func (st *MemSessionStore) Get(key interface{}) interface{} { func (st *MemSessionStore) Get(key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
@ -33,6 +38,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} {
return nil return nil
} }
// delete in memory session by key
func (st *MemSessionStore) Delete(key interface{}) error { func (st *MemSessionStore) Delete(key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -40,6 +46,7 @@ func (st *MemSessionStore) Delete(key interface{}) error {
return nil return nil
} }
// clear all values in memory session
func (st *MemSessionStore) Flush() error { func (st *MemSessionStore) Flush() error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -47,28 +54,31 @@ func (st *MemSessionStore) Flush() error {
return nil return nil
} }
// get this id of memory session store
func (st *MemSessionStore) SessionID() string { func (st *MemSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MemSessionStore) SessionRelease() { // Implement method, no used.
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
} }
type MemProvider struct { type MemProvider struct {
lock sync.RWMutex //用来锁 lock sync.RWMutex // locker
sessions map[string]*list.Element //用来存储在内存 sessions map[string]*list.Element // map in memory
list *list.List //用来做gc list *list.List // for gc
maxlifetime int64 maxlifetime int64
savePath string savePath string
} }
// init memory session
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime pder.maxlifetime = maxlifetime
pder.savePath = savePath pder.savePath = savePath
return nil return nil
} }
// get memory session store by sid
func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
pder.lock.RLock() pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok { if element, ok := pder.sessions[sid]; ok {
@ -87,6 +97,7 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
return nil, nil return nil, nil
} }
// check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(sid string) bool { func (pder *MemProvider) SessionExist(sid string) bool {
pder.lock.RLock() pder.lock.RLock()
defer pder.lock.RUnlock() defer pder.lock.RUnlock()
@ -97,6 +108,7 @@ func (pder *MemProvider) SessionExist(sid string) bool {
} }
} }
// generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
pder.lock.RLock() pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok { if element, ok := pder.sessions[oldsid]; ok {
@ -120,6 +132,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
return nil, nil return nil, nil
} }
// delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(sid string) error { func (pder *MemProvider) SessionDestroy(sid string) error {
pder.lock.Lock() pder.lock.Lock()
defer pder.lock.Unlock() defer pder.lock.Unlock()
@ -131,6 +144,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error {
return nil return nil
} }
// clean expired session stores in memory session
func (pder *MemProvider) SessionGC() { func (pder *MemProvider) SessionGC() {
pder.lock.RLock() pder.lock.RLock()
for { for {
@ -152,10 +166,12 @@ func (pder *MemProvider) SessionGC() {
pder.lock.RUnlock() pder.lock.RUnlock()
} }
// get count number of memory session
func (pder *MemProvider) SessionAll() int { func (pder *MemProvider) SessionAll() int {
return pder.list.Len() return pder.list.Len()
} }
// expand time of session store by id in memory session
func (pder *MemProvider) SessionUpdate(sid string) error { func (pder *MemProvider) SessionUpdate(sid string) error {
pder.lock.Lock() pder.lock.Lock()
defer pder.lock.Unlock() defer pder.lock.Unlock()

35
session/sess_mem_test.go Normal file
View File

@ -0,0 +1,35 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
err := sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -1,14 +1,16 @@
package session package session
//CREATE TABLE `session` ( // mysql session support need create table as sql:
// `session_key` char(64) NOT NULL, // CREATE TABLE `session` (
// `session_data` blob, // `session_key` char(64) NOT NULL,
// `session_expiry` int(11) unsigned NOT NULL, // session_data` blob,
// PRIMARY KEY (`session_key`) // `session_expiry` int(11) unsigned NOT NULL,
//) ENGINE=MyISAM DEFAULT CHARSET=utf8; // PRIMARY KEY (`session_key`)
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
import ( import (
"database/sql" "database/sql"
"net/http"
"sync" "sync"
"time" "time"
@ -17,6 +19,7 @@ import (
var mysqlpder = &MysqlProvider{} var mysqlpder = &MysqlProvider{}
// mysql session store
type MysqlSessionStore struct { type MysqlSessionStore struct {
c *sql.DB c *sql.DB
sid string sid string
@ -24,6 +27,8 @@ type MysqlSessionStore struct {
values map[interface{}]interface{} values map[interface{}]interface{}
} }
// set value in mysql session.
// it is temp value in map.
func (st *MysqlSessionStore) Set(key, value interface{}) error { func (st *MysqlSessionStore) Set(key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -31,6 +36,7 @@ func (st *MysqlSessionStore) Set(key, value interface{}) error {
return nil return nil
} }
// get value from mysql session
func (st *MysqlSessionStore) Get(key interface{}) interface{} { func (st *MysqlSessionStore) Get(key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
@ -42,6 +48,7 @@ func (st *MysqlSessionStore) Get(key interface{}) interface{} {
return nil return nil
} }
// delete value in mysql session
func (st *MysqlSessionStore) Delete(key interface{}) error { func (st *MysqlSessionStore) Delete(key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -49,6 +56,7 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
return nil return nil
} }
// clear all values in mysql session
func (st *MysqlSessionStore) Flush() error { func (st *MysqlSessionStore) Flush() error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -56,26 +64,31 @@ func (st *MysqlSessionStore) Flush() error {
return nil return nil
} }
// get session id of this mysql session store
func (st *MysqlSessionStore) SessionID() string { func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MysqlSessionStore) SessionRelease() { // save mysql session values to database.
// must call this method to save values to database.
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
if len(st.values) > 0 { b, err := encodeGob(st.values)
b, err := encodeGob(st.values) if err != nil {
if err != nil { return
return
}
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
} }
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
} }
// mysql session provider
type MysqlProvider struct { type MysqlProvider struct {
maxlifetime int64 maxlifetime int64
savePath string savePath string
} }
// connect to mysql
func (mp *MysqlProvider) connectInit() *sql.DB { func (mp *MysqlProvider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath) db, e := sql.Open("mysql", mp.savePath)
if e != nil { if e != nil {
@ -84,19 +97,23 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
return db return db
} }
// init mysql session.
// savepath is the connection string of mysql.
func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error { func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime mp.maxlifetime = maxlifetime
mp.savePath = savePath mp.savePath = savePath
return nil return nil
} }
// get mysql session by sid
func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", sid) row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
} }
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(sessiondata) == 0 { if len(sessiondata) == 0 {
@ -111,8 +128,10 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
// check mysql session exist
func (mp *MysqlProvider) SessionExist(sid string) bool { func (mp *MysqlProvider) SessionExist(sid string) bool {
c := mp.connectInit() c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid) row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
@ -123,6 +142,7 @@ func (mp *MysqlProvider) SessionExist(sid string) bool {
} }
} }
// generate new sid for mysql session
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", oldsid) row := c.QueryRow("select session_data from session where session_key=?", oldsid)
@ -145,6 +165,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
return rs, nil return rs, nil
} }
// delete mysql session by sid
func (mp *MysqlProvider) SessionDestroy(sid string) error { func (mp *MysqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=?", sid) c.Exec("DELETE FROM session where session_key=?", sid)
@ -152,6 +173,7 @@ func (mp *MysqlProvider) SessionDestroy(sid string) error {
return nil return nil
} }
// delete expired values in mysql session
func (mp *MysqlProvider) SessionGC() { func (mp *MysqlProvider) SessionGC() {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
@ -159,6 +181,7 @@ func (mp *MysqlProvider) SessionGC() {
return return
} }
// count values in mysql session
func (mp *MysqlProvider) SessionAll() int { func (mp *MysqlProvider) SessionAll() int {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -10,18 +11,21 @@ import (
var redispder = &RedisProvider{} var redispder = &RedisProvider{}
// redis max pool size
var MAX_POOL_SIZE = 100 var MAX_POOL_SIZE = 100
var redisPool chan redis.Conn var redisPool chan redis.Conn
// redis session store
type RedisSessionStore struct { type RedisSessionStore struct {
c redis.Conn p *redis.Pool
sid string sid string
lock sync.RWMutex lock sync.RWMutex
values map[interface{}]interface{} values map[interface{}]interface{}
maxlifetime int64 maxlifetime int64
} }
// set value in redis session
func (rs *RedisSessionStore) Set(key, value interface{}) error { func (rs *RedisSessionStore) Set(key, value interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
@ -29,6 +33,7 @@ func (rs *RedisSessionStore) Set(key, value interface{}) error {
return nil return nil
} }
// get value in redis session
func (rs *RedisSessionStore) Get(key interface{}) interface{} { func (rs *RedisSessionStore) Get(key interface{}) interface{} {
rs.lock.RLock() rs.lock.RLock()
defer rs.lock.RUnlock() defer rs.lock.RUnlock()
@ -40,6 +45,7 @@ func (rs *RedisSessionStore) Get(key interface{}) interface{} {
return nil return nil
} }
// delete value in redis session
func (rs *RedisSessionStore) Delete(key interface{}) error { func (rs *RedisSessionStore) Delete(key interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
@ -47,6 +53,7 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
return nil return nil
} }
// clear all values in redis session
func (rs *RedisSessionStore) Flush() error { func (rs *RedisSessionStore) Flush() error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
@ -54,22 +61,31 @@ func (rs *RedisSessionStore) Flush() error {
return nil return nil
} }
// get redis session id
func (rs *RedisSessionStore) SessionID() string { func (rs *RedisSessionStore) SessionID() string {
return rs.sid return rs.sid
} }
func (rs *RedisSessionStore) SessionRelease() { // save session values to redis
defer rs.c.Close() func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
if len(rs.values) > 0 { c := rs.p.Get()
b, err := encodeGob(rs.values) defer c.Close()
if err != nil {
return // if rs.values is empty, return directly
} if len(rs.values) < 1 {
rs.c.Do("SET", rs.sid, string(b)) c.Do("DEL", rs.sid)
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime) return
} }
b, err := encodeGob(rs.values)
if err != nil {
return
}
c.Do("SET", rs.sid, string(b), "EX", rs.maxlifetime)
} }
// redis session provider
type RedisProvider struct { type RedisProvider struct {
maxlifetime int64 maxlifetime int64
savePath string savePath string
@ -78,8 +94,9 @@ type RedisProvider struct {
poollist *redis.Pool poollist *redis.Pool
} }
//savepath like redisserveraddr,poolsize,password // init redis session
//127.0.0.1:6379,100,astaxie // savepath like redis server addr,pool size,password
// e.g. 127.0.0.1:6379,100,astaxie
func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
@ -115,12 +132,11 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil return nil
} }
// read redis session by sid
func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
c := rp.poollist.Get() c := rp.poollist.Get()
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { defer c.Close()
c.Do("SET", sid)
}
c.Do("EXPIRE", sid, rp.maxlifetime)
kvs, err := redis.String(c.Do("GET", sid)) kvs, err := redis.String(c.Do("GET", sid))
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(kvs) == 0 { if len(kvs) == 0 {
@ -131,13 +147,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
return nil, err return nil, err
} }
} }
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil return rs, nil
} }
// check redis session exist by sid
func (rp *RedisProvider) SessionExist(sid string) bool { func (rp *RedisProvider) SessionExist(sid string) bool {
c := rp.poollist.Get() c := rp.poollist.Get()
defer c.Close() defer c.Close()
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
return false return false
} else { } else {
@ -145,13 +164,21 @@ func (rp *RedisProvider) SessionExist(sid string) bool {
} }
} }
// generate new sid for redis session
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := rp.poollist.Get() c := rp.poollist.Get()
if existed, err := redis.Int(c.Do("EXISTS", oldsid)); err != nil || existed == 0 { defer c.Close()
c.Do("SET", oldsid)
if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Do("SET", sid, "", "EX", rp.maxlifetime)
} else {
c.Do("RENAME", oldsid, sid)
c.Do("EXPIRE", sid, rp.maxlifetime)
} }
c.Do("RENAME", oldsid, sid)
c.Do("EXPIRE", sid, rp.maxlifetime)
kvs, err := redis.String(c.Do("GET", sid)) kvs, err := redis.String(c.Do("GET", sid))
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(kvs) == 0 { if len(kvs) == 0 {
@ -162,24 +189,27 @@ func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
return nil, err return nil, err
} }
} }
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil return rs, nil
} }
// delete redis session by id
func (rp *RedisProvider) SessionDestroy(sid string) error { func (rp *RedisProvider) SessionDestroy(sid string) error {
c := rp.poollist.Get() c := rp.poollist.Get()
defer c.Close() defer c.Close()
c.Do("DEL", sid) c.Do("DEL", sid)
return nil return nil
} }
// Impelment method, no used.
func (rp *RedisProvider) SessionGC() { func (rp *RedisProvider) SessionGC() {
return return
} }
//@todo // @todo
func (rp *RedisProvider) SessionAll() int { func (rp *RedisProvider) SessionAll() int {
return 0 return 0
} }

View File

@ -1,6 +1,8 @@
package session package session
import ( import (
"crypto/aes"
"encoding/json"
"testing" "testing"
) )
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
t.Error("decode int error") t.Error("decode int error")
} }
} }
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst := make(map[interface{}]interface{})
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(managerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie != false {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

188
session/sess_utils.go Normal file
View File

@ -0,0 +1,188 @@
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. encodeGob.
if b, err = encodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. decodeGob.
if dst, err := decodeGob(b); err != nil {
return nil, err
} else {
return dst, nil
}
// Done.
return nil, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -13,17 +14,20 @@ import (
"time" "time"
) )
// SessionStore contains all data for one session process with specific id.
type SessionStore interface { type SessionStore interface {
Set(key, value interface{}) error //set session value Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
// Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id.
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)
@ -47,90 +51,86 @@ func Register(name string, provide Provider) {
provides[name] = provide provides[name] = provide
} }
type Manager struct { type managerConfig struct {
cookieName string //private cookiename CookieName string `json:"cookieName"`
provider Provider EnableSetCookie bool `json:"enableSetCookie,omitempty"`
maxlifetime int64 Gclifetime int64 `json:"gclifetime"`
hashfunc string //support md5 & sha1 Maxlifetime int64 `json:"maxLifetime"`
hashkey string Maxage int `json:"maxage"`
maxage int //cookielifetime Secure bool `json:"secure"`
secure bool SessionIDHashFunc string `json:"sessionIDHashFunc"`
options []interface{} SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int64 `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
} }
//options // Manager contains Provider and its configuration.
//1. is https default false type Manager struct {
//2. hashfunc default sha1 provider Provider
//3. hashkey default beegosessionkey config *managerConfig
//4. maxage default is none }
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
// Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
// 3. memory
// 4. redis
// 5. mysql
// json config:
// 1. is https default false
// 2. hashfunc default sha1
// 3. hashkey default beegosessionkey
// 4. maxage default is none
func NewManager(provideName, config string) (*Manager, error) {
provider, ok := provides[provideName] provider, ok := provides[provideName]
if !ok { if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
} }
provider.SessionInit(maxlifetime, savePath) cf := new(managerConfig)
secure := false cf.EnableSetCookie = true
if len(options) > 0 { err := json.Unmarshal([]byte(config), cf)
secure = options[0].(bool) if err != nil {
return nil, err
} }
hashfunc := "sha1" if cf.Maxlifetime == 0 {
if len(options) > 1 { cf.Maxlifetime = cf.Gclifetime
hashfunc = options[1].(string)
} }
hashkey := "beegosessionkey" err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if len(options) > 2 { if err != nil {
hashkey = options[2].(string) return nil, err
} }
maxage := -1 if cf.SessionIDHashFunc == "" {
if len(options) > 3 { cf.SessionIDHashFunc = "sha1"
switch options[3].(type) {
case int:
if options[3].(int) > 0 {
maxage = options[3].(int)
} else if options[3].(int) < 0 {
maxage = 0
}
case int64:
if options[3].(int64) > 0 {
maxage = int(options[3].(int64))
} else if options[3].(int64) < 0 {
maxage = 0
}
case int32:
if options[3].(int32) > 0 {
maxage = int(options[3].(int32))
} else if options[3].(int32) < 0 {
maxage = 0
}
}
} }
if cf.SessionIDHashKey == "" {
cf.SessionIDHashKey = string(generateRandomKey(16))
}
return &Manager{ return &Manager{
provider: provider, provider,
cookieName: cookieName, cf,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
maxage: maxage,
secure: secure,
options: options,
}, nil }, nil
} }
//get Session // Start session. generate or read the session id from http request.
// if session id exists, return SessionStore with this id.
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
sid := manager.sessionId(r) sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
} }
http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
} else { } else {
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
@ -139,55 +139,65 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} else { } else {
sid = manager.sessionId(r) sid = manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
} }
http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
} }
} }
return return
} }
//Destroy sessionid // Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
return return
} else { } else {
manager.provider.SessionDestroy(cookie.Value) manager.provider.SessionDestroy(cookie.Value)
expiration := time.Now() expiration := time.Now()
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} cookie := http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
} }
// Get SessionStore by its id.
func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err error) { func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err error) {
sessions, err = manager.provider.SessionRead(sid) sessions, err = manager.provider.SessionRead(sid)
return return
} }
// Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() { func (manager *Manager) GC() {
manager.provider.SessionGC() manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
} }
// Regenerate a session id for this SessionStore who's id is saving in http request.
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r) sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil && cookie.Value == "" { if err != nil && cookie.Value == "" {
//delete old cookie //delete old cookie
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure, Secure: manager.config.Secure,
} }
} else { } else {
oldsid, _ := url.QueryUnescape(cookie.Value) oldsid, _ := url.QueryUnescape(cookie.Value)
@ -196,44 +206,47 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
} }
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
return return
} }
// Get all active sessions count number.
func (manager *Manager) GetActiveSession() int { func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll() return manager.provider.SessionAll()
} }
// Set hash function for generating session id.
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
manager.hashfunc = hasfunc manager.config.SessionIDHashFunc = hasfunc
manager.hashkey = hashkey manager.config.SessionIDHashKey = hashkey
} }
// Set cookie with https.
func (manager *Manager) SetSecure(secure bool) { func (manager *Manager) SetSecure(secure bool) {
manager.secure = secure manager.config.Secure = secure
} }
//remote_addr cruunixnano randdata // generate session id with rand string, unix nano time, remote addr by hash function.
func (manager *Manager) sessionId(r *http.Request) (sid string) { func (manager *Manager) sessionId(r *http.Request) (sid string) {
bs := make([]byte, 24) bs := make([]byte, 24)
if _, err := io.ReadFull(rand.Reader, bs); err != nil { if _, err := io.ReadFull(rand.Reader, bs); err != nil {
return "" return ""
} }
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" { if manager.config.SessionIDHashFunc == "md5" {
h := md5.New() h := md5.New()
h.Write([]byte(sig)) h.Write([]byte(sig))
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else if manager.hashfunc == "sha1" { } else if manager.config.SessionIDHashFunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else { } else {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} }

View File

@ -8,6 +8,7 @@ import (
var port = "" var port = ""
var baseUrl = "http://localhost:" var baseUrl = "http://localhost:"
// beego test request client
type TestHttpRequest struct { type TestHttpRequest struct {
httplib.BeegoHttpRequest httplib.BeegoHttpRequest
} }
@ -24,22 +25,27 @@ func getPort() string {
return port return port
} }
// returns test client in GET method
func Get(path string) *TestHttpRequest { func Get(path string) *TestHttpRequest {
return &TestHttpRequest{*httplib.Get(baseUrl + getPort() + path)} return &TestHttpRequest{*httplib.Get(baseUrl + getPort() + path)}
} }
// returns test client in POST method
func Post(path string) *TestHttpRequest { func Post(path string) *TestHttpRequest {
return &TestHttpRequest{*httplib.Post(baseUrl + getPort() + path)} return &TestHttpRequest{*httplib.Post(baseUrl + getPort() + path)}
} }
// returns test client in PUT method
func Put(path string) *TestHttpRequest { func Put(path string) *TestHttpRequest {
return &TestHttpRequest{*httplib.Put(baseUrl + getPort() + path)} return &TestHttpRequest{*httplib.Put(baseUrl + getPort() + path)}
} }
// returns test client in DELETE method
func Delete(path string) *TestHttpRequest { func Delete(path string) *TestHttpRequest {
return &TestHttpRequest{*httplib.Delete(baseUrl + getPort() + path)} return &TestHttpRequest{*httplib.Delete(baseUrl + getPort() + path)}
} }
// returns test client in HEAD method
func Head(path string) *TestHttpRequest { func Head(path string) *TestHttpRequest {
return &TestHttpRequest{*httplib.Head(baseUrl + getPort() + path)} return &TestHttpRequest{*httplib.Head(baseUrl + getPort() + path)}
} }

View File

@ -29,16 +29,13 @@ type pointerInfo struct {
used []int used []int
} }
//
// print the data in console // print the data in console
//
func Display(data ...interface{}) { func Display(data ...interface{}) {
display(true, data...) display(true, data...)
} }
//
// return string // return data print string
//
func GetDisplayString(data ...interface{}) string { func GetDisplayString(data ...interface{}) string {
return display(false, data...) return display(false, data...)
} }
@ -67,9 +64,7 @@ func display(displayed bool, data ...interface{}) string {
return buf.String() return buf.String()
} }
// // return data dump and format bytes
// return fomateinfo
//
func fomateinfo(headlen int, data ...interface{}) []byte { func fomateinfo(headlen int, data ...interface{}) []byte {
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
@ -108,6 +103,7 @@ func fomateinfo(headlen int, data ...interface{}) []byte {
return buf.Bytes() return buf.Bytes()
} }
// check data is golang basic type
func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool {
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:
@ -158,6 +154,7 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo,
return false return false
} }
// dump value
func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) {
var t = val.Kind() var t = val.Kind()
@ -367,6 +364,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo,
} }
} }
// dump pointer value
func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
var anyused = false var anyused = false
var pointerNum = 0 var pointerNum = 0
@ -434,9 +432,7 @@ func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
} }
} }
// // get stack bytes
// get stack info
//
func stack(skip int, indent string) []byte { func stack(skip int, indent string) []byte {
var buf = new(bytes.Buffer) var buf = new(bytes.Buffer)
@ -455,7 +451,7 @@ func stack(skip int, indent string) []byte {
return buf.Bytes() return buf.Bytes()
} }
// function returns, if possible, the name of the function containing the PC. // return the name of the function containing the PC if possible,
func function(pc uintptr) []byte { func function(pc uintptr) []byte {
fn := runtime.FuncForPC(pc) fn := runtime.FuncForPC(pc)
if fn == nil { if fn == nil {

View File

@ -13,12 +13,15 @@ package toolbox
//AddHealthCheck("database",&DatabaseCheck{}) //AddHealthCheck("database",&DatabaseCheck{})
// health checker map
var AdminCheckList map[string]HealthChecker var AdminCheckList map[string]HealthChecker
// health checker interface
type HealthChecker interface { type HealthChecker interface {
Check() error Check() error
} }
// add health checker with name string
func AddHealthCheck(name string, hc HealthChecker) { func AddHealthCheck(name string, hc HealthChecker) {
AdminCheckList[name] = hc AdminCheckList[name] = hc
} }

View File

@ -19,6 +19,7 @@ func init() {
pid = os.Getpid() pid = os.Getpid()
} }
// parse input command string
func ProcessInput(input string, w io.Writer) { func ProcessInput(input string, w io.Writer) {
switch input { switch input {
case "lookup goroutine": case "lookup goroutine":
@ -44,6 +45,7 @@ func ProcessInput(input string, w io.Writer) {
} }
} }
// record memory profile in pprof
func MemProf() { func MemProf() {
if f, err := os.Create("mem-" + strconv.Itoa(pid) + ".memprof"); err != nil { if f, err := os.Create("mem-" + strconv.Itoa(pid) + ".memprof"); err != nil {
log.Fatal("record memory profile failed: %v", err) log.Fatal("record memory profile failed: %v", err)
@ -54,6 +56,7 @@ func MemProf() {
} }
} }
// start cpu profile monitor
func StartCPUProfile() { func StartCPUProfile() {
f, err := os.Create("cpu-" + strconv.Itoa(pid) + ".pprof") f, err := os.Create("cpu-" + strconv.Itoa(pid) + ".pprof")
if err != nil { if err != nil {
@ -62,10 +65,12 @@ func StartCPUProfile() {
pprof.StartCPUProfile(f) pprof.StartCPUProfile(f)
} }
// stop cpu profile monitor
func StopCPUProfile() { func StopCPUProfile() {
pprof.StopCPUProfile() pprof.StopCPUProfile()
} }
// print gc information to io.Writer
func PrintGCSummary(w io.Writer) { func PrintGCSummary(w io.Writer) {
memStats := &runtime.MemStats{} memStats := &runtime.MemStats{}
runtime.ReadMemStats(memStats) runtime.ReadMemStats(memStats)
@ -114,7 +119,7 @@ func avg(items []time.Duration) time.Duration {
return time.Duration(int64(sum) / int64(len(items))) return time.Duration(int64(sum) / int64(len(items)))
} }
// human readable format // format bytes number friendly
func toH(bytes uint64) string { func toH(bytes uint64) string {
switch { switch {
case bytes < 1024: case bytes < 1024:

View File

@ -7,6 +7,7 @@ import (
"time" "time"
) )
// Statistics struct
type Statistics struct { type Statistics struct {
RequestUrl string RequestUrl string
RequestController string RequestController string
@ -16,12 +17,15 @@ type Statistics struct {
TotalTime time.Duration TotalTime time.Duration
} }
// UrlMap contains several statistics struct to log different data
type UrlMap struct { type UrlMap struct {
lock sync.RWMutex lock sync.RWMutex
LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit
urlmap map[string]map[string]*Statistics urlmap map[string]map[string]*Statistics
} }
// add statistics task.
// it needs request method, request url, request controller and statistics time duration
func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController string, requesttime time.Duration) { func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController string, requesttime time.Duration) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
@ -65,6 +69,7 @@ func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController stri
} }
} }
// put url statistics result in io.Writer
func (m *UrlMap) GetMap(rw io.Writer) { func (m *UrlMap) GetMap(rw io.Writer) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@ -78,6 +83,7 @@ func (m *UrlMap) GetMap(rw io.Writer) {
} }
} }
// global statistics data map
var StatisticsMap *UrlMap var StatisticsMap *UrlMap
func init() { func init() {

View File

@ -53,6 +53,7 @@ const (
starBit = 1 << 63 starBit = 1 << 63
) )
// time taks schedule
type Schedule struct { type Schedule struct {
Second uint64 Second uint64
Minute uint64 Minute uint64
@ -62,8 +63,10 @@ type Schedule struct {
Week uint64 Week uint64
} }
// task func type
type TaskFunc func() error type TaskFunc func() error
// task interface
type Tasker interface { type Tasker interface {
GetStatus() string GetStatus() string
Run() error Run() error
@ -73,21 +76,24 @@ type Tasker interface {
GetPrev() time.Time GetPrev() time.Time
} }
// task error
type taskerr struct { type taskerr struct {
t time.Time t time.Time
errinfo string errinfo string
} }
// task struct
type Task struct { type Task struct {
Taskname string Taskname string
Spec *Schedule Spec *Schedule
DoFunc TaskFunc DoFunc TaskFunc
Prev time.Time Prev time.Time
Next time.Time Next time.Time
Errlist []*taskerr //errtime:errinfo Errlist []*taskerr // like errtime:errinfo
ErrLimit int //max length for the errlist 0 stand for there' no limit ErrLimit int // max length for the errlist, 0 stand for no limit
} }
// add new task with name, time and func
func NewTask(tname string, spec string, f TaskFunc) *Task { func NewTask(tname string, spec string, f TaskFunc) *Task {
task := &Task{ task := &Task{
@ -99,6 +105,7 @@ func NewTask(tname string, spec string, f TaskFunc) *Task {
return task return task
} }
// get current task status
func (tk *Task) GetStatus() string { func (tk *Task) GetStatus() string {
var str string var str string
for _, v := range tk.Errlist { for _, v := range tk.Errlist {
@ -107,6 +114,7 @@ func (tk *Task) GetStatus() string {
return str return str
} }
// run task
func (tk *Task) Run() error { func (tk *Task) Run() error {
err := tk.DoFunc() err := tk.DoFunc()
if err != nil { if err != nil {
@ -117,53 +125,58 @@ func (tk *Task) Run() error {
return err return err
} }
// set next time for this task
func (tk *Task) SetNext(now time.Time) { func (tk *Task) SetNext(now time.Time) {
tk.Next = tk.Spec.Next(now) tk.Next = tk.Spec.Next(now)
} }
// get the next call time of this task
func (tk *Task) GetNext() time.Time { func (tk *Task) GetNext() time.Time {
return tk.Next return tk.Next
} }
// set prev time of this task
func (tk *Task) SetPrev(now time.Time) { func (tk *Task) SetPrev(now time.Time) {
tk.Prev = now tk.Prev = now
} }
// get prev time of this task
func (tk *Task) GetPrev() time.Time { func (tk *Task) GetPrev() time.Time {
return tk.Prev return tk.Prev
} }
//前6个字段分别表示 // six columns mean
// 秒钟0-59 // second0-59
// 分钟0-59 // minute0-59
// 小时1-23 // hour1-23
// 日期1-31 // day1-31
// 月份1-12 // month1-12
// 星期0-60表示周日 // week0-60 means Sunday
//还可以用一些特殊符号 // some signals
// * 表示任何时刻 // * any time
// , 表示分割如第三段里2,4表示2点和4点执行 // ,  separate signal
//   表示一个段,如第三端里: 1-5就表示1到5点 //   duration
// /n : 表示每个n的单位执行一次如第三段里*/1, 就表示每隔1个小时执行一次命令。也可以写成1-23/1. // /n : do as n times of time duration
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
// 0/30 * * * * * 每30秒 执行 // 0/30 * * * * * every 30s
// 0 43 21 * * * 21:43 执行 // 0 43 21 * * * 21:43
// 0 15 05 * * *    05:15 执行 // 0 15 05 * * *    05:15
// 0 0 17 * * * 17:00 执行 // 0 0 17 * * * 17:00
// 0 0 17 * * 1 每周一的 17:00 执行 // 0 0 17 * * 1 17:00 in every Monday
// 0 0,10 17 * * 0,2,3 每周日,周二,周三的 17:00和 17:10 执行 // 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday
// 0 0-10 17 1 * * 毎月1日从 17:00到7:10 毎隔1分钟 执行 // 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month
// 0 0 0 1,15 * 1 毎月1日和 15日和 一日的 0:00 执行 // 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month
// 0 42 4 1 * *     毎月1日的 4:42分 执行 // 0 42 4 1 * *     4:42 on the 1st day of month
// 0 0 21 * * 1-6   周一到周六 21:00 执行 // 0 0 21 * * 1-6   21:00 from Monday to Saturday
// 0 0,10,20,30,40,50 * * * *  每隔10分 执行 // 0 0,10,20,30,40,50 * * * *  every 10 min duration
// 0 */10 * * * *        每隔10分 执行 // 0 */10 * * * *        every 10 min duration
// 0 * 1 * * *         从1:0到1:59 每隔1分钟 执行 // 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time
// 0 0 1 * * *         1:00 执行 // 0 0 1 * * *         1:00
// 0 0 */1 * * *        毎时0分 每隔1小时 执行 // 0 0 */1 * * *        0 min of hour in 1 hour duration
// 0 0 * * * *         毎时0分 每隔1小时 执行 // 0 0 * * * *         0 min of hour in 1 hour duration
// 0 2 8-20/3 * * *       8:02,11:02,14:02,17:02,20:02 执行 // 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02
// 0 30 5 1,15 * *       1日 和 15日的 5:30 执行 // 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month
func (t *Task) SetCron(spec string) { func (t *Task) SetCron(spec string) {
t.Spec = t.parse(spec) t.Spec = t.parse(spec)
} }
@ -252,6 +265,7 @@ func (t *Task) parseSpec(spec string) *Schedule {
return nil return nil
} }
// set schedule to next time
func (s *Schedule) Next(t time.Time) time.Time { func (s *Schedule) Next(t time.Time) time.Time {
// Start at the earliest possible time (the upcoming second). // Start at the earliest possible time (the upcoming second).
@ -349,6 +363,7 @@ func dayMatches(s *Schedule, t time.Time) bool {
return domMatch || dowMatch return domMatch || dowMatch
} }
// start all tasks
func StartTask() { func StartTask() {
go run() go run()
} }
@ -388,20 +403,23 @@ func run() {
} }
} }
// start all tasks
func StopTask() { func StopTask() {
stop <- true stop <- true
} }
// add task with name
func AddTask(taskname string, t Tasker) { func AddTask(taskname string, t Tasker) {
AdminTaskList[taskname] = t AdminTaskList[taskname] = t
} }
//sort map for tasker // sort map for tasker
type MapSorter struct { type MapSorter struct {
Keys []string Keys []string
Vals []Tasker Vals []Tasker
} }
// create new tasker map
func NewMapSorter(m map[string]Tasker) *MapSorter { func NewMapSorter(m map[string]Tasker) *MapSorter {
ms := &MapSorter{ ms := &MapSorter{
Keys: make([]string, 0, len(m)), Keys: make([]string, 0, len(m)),
@ -414,6 +432,7 @@ func NewMapSorter(m map[string]Tasker) *MapSorter {
return ms return ms
} }
// sort tasker map
func (ms *MapSorter) Sort() { func (ms *MapSorter) Sort() {
sort.Sort(ms) sort.Sort(ms)
} }

View File

@ -5,6 +5,7 @@ import (
"runtime" "runtime"
) )
// get function name
func GetFuncName(i interface{}) string { func GetFuncName(i interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
} }

45
utils/captcha/README.md Normal file
View File

@ -0,0 +1,45 @@
# Captcha
an example for use captcha
```
package controllers
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/utils/captcha"
)
var cpt *captcha.Captcha
func init() {
// use beego cache system store the captcha data
store := cache.NewMemoryCache()
cpt = captcha.NewWithFilter("/captcha/", store)
}
type MainController struct {
beego.Controller
}
func (this *MainController) Get() {
this.TplNames = "index.tpl"
}
func (this *MainController) Post() {
this.TplNames = "index.tpl"
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
}
```
template usage
```
{{.Success}}
<form action="/" method="post">
{{create_captcha}}
<input name="captcha" type="text">
</form>
```

251
utils/captcha/captcha.go Normal file
View File

@ -0,0 +1,251 @@
// an example for use captcha
//
// ```
// package controllers
//
// import (
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/cache"
// "github.com/astaxie/beego/utils/captcha"
// )
//
// var cpt *captcha.Captcha
//
// func init() {
// // use beego cache system store the captcha data
// store := cache.NewMemoryCache()
// cpt = captcha.NewWithFilter("/captcha/", store)
// }
//
// type MainController struct {
// beego.Controller
// }
//
// func (this *MainController) Get() {
// this.TplNames = "index.tpl"
// }
//
// func (this *MainController) Post() {
// this.TplNames = "index.tpl"
//
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
// }
// ```
//
// template usage
//
// ```
// {{.Success}}
// <form action="/" method="post">
// {{create_captcha}}
// <input name="captcha" type="text">
// </form>
// ```
package captcha
import (
"fmt"
"html/template"
"net/http"
"path"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/utils"
)
var (
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
)
const (
// default captcha attributes
challengeNums = 6
expiration = 600
fieldIdName = "captcha_id"
fieldCaptchaName = "captcha"
cachePrefix = "captcha_"
urlPrefix = "/captcha/"
)
// Captcha struct
type Captcha struct {
// beego cache store
store cache.Cache
// url prefix for captcha image
urlPrefix string
// specify captcha id input field name
FieldIdName string
// specify captcha result input field name
FieldCaptchaName string
// captcha image width and height
StdWidth int
StdHeight int
// captcha chars nums
ChallengeNums int
// captcha expiration seconds
Expiration int64
// cache key prefix
CachePrefix string
}
// generate key string
func (c *Captcha) key(id string) string {
return c.CachePrefix + id
}
// generate rand chars with default chars
func (c *Captcha) genRandChars() []byte {
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
}
// beego filter handler for serve captcha image
func (c *Captcha) Handler(ctx *context.Context) {
var chars []byte
id := path.Base(ctx.Request.RequestURI)
if i := strings.Index(id, "."); i != -1 {
id = id[:i]
}
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok {
chars = v
} else {
ctx.Output.SetStatus(404)
ctx.WriteString("captcha not found")
return
}
// reload captcha
if len(ctx.Input.Query("reload")) > 0 {
chars = c.genRandChars()
if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error")
beego.Error("Reload Create Captcha Error:", err)
return
}
}
img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
beego.Error("Write Captcha Image Error:", err)
}
}
// tempalte func for output html
func (c *Captcha) CreateCaptchaHtml() template.HTML {
value, err := c.CreateCaptcha()
if err != nil {
beego.Error("Create Captcha Error:", err)
return ""
}
// create html
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
`<a class="captcha" href="javascript:">`+
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
}
// create a new captcha id
func (c *Captcha) CreateCaptcha() (string, error) {
// generate captcha id
id := string(utils.RandomCreateBytes(15))
// get the captcha chars
chars := c.genRandChars()
// save to store
if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil {
return "", err
}
return id, nil
}
// verify from a request
func (c *Captcha) VerifyReq(req *http.Request) bool {
req.ParseForm()
return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName))
}
// direct verify id and challenge string
func (c *Captcha) Verify(id string, challenge string) (success bool) {
if len(challenge) == 0 || len(id) == 0 {
return
}
var chars []byte
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) {
chars = v
} else {
return
}
defer func() {
// finally remove it
c.store.Delete(key)
}()
// verify challenge
for i, c := range chars {
if c != challenge[i]-48 {
return
}
}
return true
}
// create a new captcha.Captcha
func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
cpt := &Captcha{}
cpt.store = store
cpt.FieldIdName = fieldIdName
cpt.FieldCaptchaName = fieldCaptchaName
cpt.ChallengeNums = challengeNums
cpt.Expiration = expiration
cpt.CachePrefix = cachePrefix
cpt.StdWidth = stdWidth
cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 {
urlPrefix = urlPrefix
}
if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/"
}
cpt.urlPrefix = urlPrefix
return cpt
}
// create a new captcha.Captcha and auto AddFilter for serve captacha image
// and add a tempalte func for output html
func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
// add to template func map
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
return cpt
}

484
utils/captcha/image.go Normal file
View File

@ -0,0 +1,484 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"bytes"
"image"
"image/color"
"image/png"
"io"
"math"
)
const (
fontWidth = 11
fontHeight = 18
blackChar = 1
// Standard width and height of a captcha image.
stdWidth = 240
stdHeight = 80
// Maximum absolute skew factor of a single digit.
maxSkew = 0.7
// Number of background circles.
circleCount = 20
)
var font = [][]byte{
{ // 0
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 1
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 2
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 3
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 4
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
},
{ // 5
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 6
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 7
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
},
{ // 8
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 9
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
},
}
type Image struct {
*image.Paletted
numWidth int
numHeight int
dotSize int
}
var prng = &siprng{}
// randIntn returns a pseudorandom non-negative int in range [0, n).
func randIntn(n int) int {
return prng.Intn(n)
}
// randInt returns a pseudorandom int in range [from, to].
func randInt(from, to int) int {
return prng.Intn(to+1-from) + from
}
// randFloat returns a pseudorandom float64 in range [from, to].
func randFloat(from, to float64) float64 {
return (to-from)*prng.Float64() + from
}
func randomPalette() color.Palette {
p := make([]color.Color, circleCount+1)
// Transparent color.
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
// Primary color.
prim := color.RGBA{
uint8(randIntn(129)),
uint8(randIntn(129)),
uint8(randIntn(129)),
0xFF,
}
p[1] = prim
// Circle colors.
for i := 2; i <= circleCount; i++ {
p[i] = randomBrightness(prim, 255)
}
return p
}
// NewImage returns a new captcha image of the given width and height with the
// given digits, where each digit must be in range 0-9.
func NewImage(digits []byte, width, height int) *Image {
m := new(Image)
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
m.calculateSizes(width, height, len(digits))
// Randomly position captcha inside the image.
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
maxy := height - m.numHeight - m.dotSize*2
var border int
if width > height {
border = height / 5
} else {
border = width / 5
}
x := randInt(border, maxx-border)
y := randInt(border, maxy-border)
// Draw digits.
for _, n := range digits {
m.drawDigit(font[n], x, y)
x += m.numWidth + m.dotSize
}
// Draw strike-through line.
m.strikeThrough()
// Apply wave distortion.
m.distort(randFloat(5, 10), randFloat(100, 200))
// Fill image with random circles.
m.fillWithCircles(circleCount, m.dotSize)
return m
}
// encodedPNG encodes an image to PNG and returns
// the result as a byte slice.
func (m *Image) encodedPNG() []byte {
var buf bytes.Buffer
if err := png.Encode(&buf, m.Paletted); err != nil {
panic(err.Error())
}
return buf.Bytes()
}
// WriteTo writes captcha image in PNG format into the given writer.
func (m *Image) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.encodedPNG())
return int64(n), err
}
func (m *Image) calculateSizes(width, height, ncount int) {
// Goal: fit all digits inside the image.
var border int
if width > height {
border = height / 4
} else {
border = width / 4
}
// Convert everything to floats for calculations.
w := float64(width - border*2)
h := float64(height - border*2)
// fw takes into account 1-dot spacing between digits.
fw := float64(fontWidth + 1)
fh := float64(fontHeight)
nc := float64(ncount)
// Calculate the width of a single digit taking into account only the
// width of the image.
nw := w / nc
// Calculate the height of a digit from this width.
nh := nw * fh / fw
// Digit too high?
if nh > h {
// Fit digits based on height.
nh = h
nw = fw / fh * nh
}
// Calculate dot size.
m.dotSize = int(nh / fh)
// Save everything, making the actual width smaller by 1 dot to account
// for spacing between digits.
m.numWidth = int(nw) - m.dotSize
m.numHeight = int(nh)
}
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
for x := fromX; x <= toX; x++ {
m.SetColorIndex(x, y, colorIdx)
}
}
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
f := 1 - radius
dfx := 1
dfy := -2 * radius
xo := 0
yo := radius
m.SetColorIndex(x, y+radius, colorIdx)
m.SetColorIndex(x, y-radius, colorIdx)
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
for xo < yo {
if f >= 0 {
yo--
dfy += 2
f += dfy
}
xo++
dfx += 2
f += dfx
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
}
}
func (m *Image) fillWithCircles(n, maxradius int) {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
for i := 0; i < n; i++ {
colorIdx := uint8(randInt(1, circleCount-1))
r := randInt(1, maxradius)
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
}
}
func (m *Image) strikeThrough() {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
y := randInt(maxy/3, maxy-maxy/3)
amplitude := randFloat(5, 20)
period := randFloat(80, 180)
dx := 2.0 * math.Pi / period
for x := 0; x < maxx; x++ {
xo := amplitude * math.Cos(float64(y)*dx)
yo := amplitude * math.Sin(float64(x)*dx)
for yn := 0; yn < m.dotSize; yn++ {
r := randInt(0, m.dotSize)
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
}
}
}
func (m *Image) drawDigit(digit []byte, x, y int) {
skf := randFloat(-maxSkew, maxSkew)
xs := float64(x)
r := m.dotSize / 2
y += randInt(-r, r)
for yo := 0; yo < fontHeight; yo++ {
for xo := 0; xo < fontWidth; xo++ {
if digit[yo*fontWidth+xo] != blackChar {
continue
}
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
}
xs += skf
x = int(xs)
}
}
func (m *Image) distort(amplude float64, period float64) {
w := m.Bounds().Max.X
h := m.Bounds().Max.Y
oldm := m.Paletted
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
dx := 2.0 * math.Pi / period
for x := 0; x < w; x++ {
for y := 0; y < h; y++ {
xo := amplude * math.Sin(float64(y)*dx)
yo := amplude * math.Cos(float64(x)*dx)
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
}
}
m.Paletted = newm
}
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
minc := min3(c.R, c.G, c.B)
maxc := max3(c.R, c.G, c.B)
if maxc > max {
return c
}
n := randIntn(int(max-maxc)) - int(minc)
return color.RGBA{
uint8(int(c.R) + n),
uint8(int(c.G) + n),
uint8(int(c.B) + n),
uint8(c.A),
}
}
func min3(x, y, z uint8) (m uint8) {
m = x
if y < m {
m = y
}
if z < m {
m = z
}
return
}
func max3(x, y, z uint8) (m uint8) {
m = x
if y > m {
m = y
}
if z > m {
m = z
}
return
}

View File

@ -0,0 +1,38 @@
package captcha
import (
"testing"
"github.com/astaxie/beego/utils"
)
type byteCounter struct {
n int64
}
func (bc *byteCounter) Write(b []byte) (int, error) {
bc.n += int64(len(b))
return len(b), nil
}
func BenchmarkNewImage(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
for i := 0; i < b.N; i++ {
NewImage(d, stdWidth, stdHeight)
}
}
func BenchmarkImageWriteTo(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
counter := &byteCounter{}
for i := 0; i < b.N; i++ {
img := NewImage(d, stdWidth, stdHeight)
img.WriteTo(counter)
b.SetBytes(counter.n)
counter.n = 0
}
}

264
utils/captcha/siprng.go Normal file
View File

@ -0,0 +1,264 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"crypto/rand"
"encoding/binary"
"io"
"sync"
)
// siprng is PRNG based on SipHash-2-4.
type siprng struct {
mu sync.Mutex
k0, k1, ctr uint64
}
// siphash implements SipHash-2-4, accepting a uint64 as a message.
func siphash(k0, k1, m uint64) uint64 {
// Initialization.
v0 := k0 ^ 0x736f6d6570736575
v1 := k1 ^ 0x646f72616e646f6d
v2 := k0 ^ 0x6c7967656e657261
v3 := k1 ^ 0x7465646279746573
t := uint64(8) << 56
// Compression.
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= m
// Compress last block.
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
return v0 ^ v1 ^ v2 ^ v3
}
// rekey sets a new PRNG key, which is read from crypto/rand.
func (p *siprng) rekey() {
var k [16]byte
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
panic(err.Error())
}
p.k0 = binary.LittleEndian.Uint64(k[0:8])
p.k1 = binary.LittleEndian.Uint64(k[8:16])
p.ctr = 1
}
// Uint64 returns a new pseudorandom uint64.
// It rekeys PRNG on the first call and every 64 MB of generated data.
func (p *siprng) Uint64() uint64 {
p.mu.Lock()
if p.ctr == 0 || p.ctr > 8*1024*1024 {
p.rekey()
}
v := siphash(p.k0, p.k1, p.ctr)
p.ctr++
p.mu.Unlock()
return v
}
func (p *siprng) Int63() int64 {
return int64(p.Uint64() & 0x7fffffffffffffff)
}
func (p *siprng) Uint32() uint32 {
return uint32(p.Uint64())
}
func (p *siprng) Int31() int32 {
return int32(p.Uint32() & 0x7fffffff)
}
func (p *siprng) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
if n <= 1<<31-1 {
return int(p.Int31n(int32(n)))
}
return int(p.Int63n(int64(n)))
}
func (p *siprng) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
v := p.Int63()
for v > max {
v = p.Int63()
}
return v % n
}
func (p *siprng) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := p.Int31()
for v > max {
v = p.Int31()
}
return v % n
}
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }

View File

@ -0,0 +1,19 @@
package captcha
import "testing"
func TestSiphash(t *testing.T) {
good := uint64(0xe849e8bb6ffe2567)
cur := siphash(0, 0, 0)
if cur != good {
t.Fatalf("siphash: expected %x, got %x", good, cur)
}
}
func BenchmarkSiprng(b *testing.B) {
b.SetBytes(8)
p := &siprng{}
for i := 0; i < b.N; i++ {
p.Uint64()
}
}

View File

@ -9,11 +9,13 @@ import (
"regexp" "regexp"
) )
// SelfPath gets compiled executable file absolute path
func SelfPath() string { func SelfPath() string {
path, _ := filepath.Abs(os.Args[0]) path, _ := filepath.Abs(os.Args[0])
return path return path
} }
// SelfDir gets compiled executable file directory
func SelfDir() string { func SelfDir() string {
return filepath.Dir(SelfPath()) return filepath.Dir(SelfPath())
} }
@ -28,8 +30,8 @@ func FileExists(name string) bool {
return true return true
} }
// search a file in paths. // Search a file in paths.
// this is offen used in search config file in /etc ~/ // this is often used in search config file in /etc ~/
func SearchFile(filename string, paths ...string) (fullpath string, err error) { func SearchFile(filename string, paths ...string) (fullpath string, err error) {
for _, path := range paths { for _, path := range paths {
if fullpath = filepath.Join(path, filename); FileExists(fullpath) { if fullpath = filepath.Join(path, filename); FileExists(fullpath) {

301
utils/mail.go Normal file
View File

@ -0,0 +1,301 @@
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/mail"
"net/smtp"
"net/textproto"
"os"
"path"
"path/filepath"
"strconv"
"strings"
)
const (
maxLineLength = 76
)
// Email is the type used for email messages
type Email struct {
Auth smtp.Auth
Identity string `json:"identity"`
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"host"`
Port int `json:"port"`
From string `json:"from"`
To []string
Bcc []string
Cc []string
Subject string
Text string // Plaintext message (optional)
HTML string // Html message (optional)
Headers textproto.MIMEHeader
Attachments []*Attachment
ReadReceipt []string
}
// Attachment is a struct representing an email attachment.
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
type Attachment struct {
Filename string
Header textproto.MIMEHeader
Content []byte
}
// NewEMail create new Email struct with config json.
// config json is followed from Email struct fields.
func NewEMail(config string) *Email {
e := new(Email)
e.Headers = textproto.MIMEHeader{}
err := json.Unmarshal([]byte(config), e)
if err != nil {
return nil
}
if e.From == "" {
e.From = e.Username
}
return e
}
// Make all send information to byte
func (e *Email) Bytes() ([]byte, error) {
buff := &bytes.Buffer{}
w := multipart.NewWriter(buff)
// Set the appropriate headers (overwriting any conflicts)
// Leave out Bcc (only included in envelope headers)
e.Headers.Set("To", strings.Join(e.To, ","))
if e.Cc != nil {
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
}
e.Headers.Set("From", e.From)
e.Headers.Set("Subject", e.Subject)
if len(e.ReadReceipt) != 0 {
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
}
e.Headers.Set("MIME-Version", "1.0")
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
// Write the envelope headers (including any custom headers)
if err := headerToBytes(buff, e.Headers); err != nil {
return nil, fmt.Errorf("Failed to render message headers: %s", err)
}
// Start the multipart/mixed part
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
header := textproto.MIMEHeader{}
// Check to see if there is a Text or HTML field
if e.Text != "" || e.HTML != "" {
subWriter := multipart.NewWriter(buff)
// Create the multipart alternative part
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
// Write the header
if err := headerToBytes(buff, header); err != nil {
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
}
// Create the body sections
if e.Text != "" {
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.Text); err != nil {
return nil, err
}
}
if e.HTML != "" {
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.HTML); err != nil {
return nil, err
}
}
if err := subWriter.Close(); err != nil {
return nil, err
}
}
// Create attachment part, if necessary
for _, a := range e.Attachments {
ap, err := w.CreatePart(a.Header)
if err != nil {
return nil, err
}
// Write the base64Wrapped content to the part
base64Wrap(ap, a.Content)
}
if err := w.Close(); err != nil {
return nil, err
}
return buff.Bytes(), nil
}
// Add attach file to the send mail
func (e *Email) AttachFile(filename string) (a *Attachment, err error) {
f, err := os.Open(filename)
if err != nil {
return
}
ct := mime.TypeByExtension(filepath.Ext(filename))
basename := path.Base(filename)
return e.Attach(f, basename, ct)
}
// Attach is used to attach content from an io.Reader to the email.
// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type.
func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) {
var buffer bytes.Buffer
if _, err = io.Copy(&buffer, r); err != nil {
return
}
at := &Attachment{
Filename: filename,
Header: textproto.MIMEHeader{},
Content: buffer.Bytes(),
}
// Get the Content-Type to be used in the MIMEHeader
if c != "" {
at.Header.Set("Content-Type", c)
} else {
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
at.Header.Set("Content-Type", "application/octet-stream")
}
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
at.Header.Set("Content-Transfer-Encoding", "base64")
e.Attachments = append(e.Attachments, at)
return at, nil
}
func (e *Email) Send() error {
if e.Auth == nil {
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
}
// Merge the To, Cc, and Bcc fields
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
// Check to make sure there is at least one recipient and one "From" address
if e.From == "" || len(to) == 0 {
return errors.New("Must specify at least one From address and one To address")
}
from, err := mail.ParseAddress(e.From)
if err != nil {
return err
}
raw, err := e.Bytes()
if err != nil {
return err
}
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
}
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
func quotePrintEncode(w io.Writer, s string) error {
var buf [3]byte
mc := 0
for i := 0; i < len(s); i++ {
c := s[i]
// We're assuming Unix style text formats as input (LF line break), and
// quoted-printble uses CRLF line breaks. (Literal CRs will become
// "=0D", but probably shouldn't be there to begin with!)
if c == '\n' {
io.WriteString(w, "\r\n")
mc = 0
continue
}
var nextOut []byte
if isPrintable(c) {
nextOut = append(buf[:0], c)
} else {
nextOut = buf[:]
qpEscape(nextOut, c)
}
// Add a soft line break if the next (encoded) byte would push this line
// to or past the limit.
if mc+len(nextOut) >= maxLineLength {
if _, err := io.WriteString(w, "=\r\n"); err != nil {
return err
}
mc = 0
}
if _, err := w.Write(nextOut); err != nil {
return err
}
mc += len(nextOut)
}
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
if mc > 0 {
io.WriteString(w, "=\r\n")
}
return nil
}
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
func isPrintable(c byte) bool {
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
}
// qpEscape is a helper function for quotePrintEncode which escapes a
// non-printable byte. Expects len(dest) == 3.
func qpEscape(dest []byte, c byte) {
const nums = "0123456789ABCDEF"
dest[0] = '='
dest[1] = nums[(c&0xf0)>>4]
dest[2] = nums[(c & 0xf)]
}
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
for k, v := range t {
// Write the header key
_, err := fmt.Fprintf(w, "%s:", k)
if err != nil {
return err
}
// Write each value in the header
for _, c := range v {
_, err := fmt.Fprintf(w, " %s\r\n", c)
if err != nil {
return err
}
}
}
return nil
}
// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
// The output is then written to the specified io.Writer
func base64Wrap(w io.Writer, b []byte) {
// 57 raw bytes per 76-byte base64 line.
const maxRaw = 57
// Buffer for each line, including trailing CRLF.
var buffer [maxLineLength + len("\r\n")]byte
copy(buffer[maxLineLength:], "\r\n")
// Process raw chunks until there's no longer enough to fill a line.
for len(b) >= maxRaw {
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
w.Write(buffer[:])
b = b[maxRaw:]
}
// Handle the last chunk of bytes.
if len(b) > 0 {
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
base64.StdEncoding.Encode(out, b)
out = append(out, "\r\n"...)
w.Write(out)
}
}

27
utils/mail_test.go Normal file
View File

@ -0,0 +1,27 @@
package utils
import "testing"
func TestMail(t *testing.T) {
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
mail := NewEMail(config)
if mail.Username != "astaxie@gmail.com" {
t.Fatal("email parse get username error")
}
if mail.Password != "astaxie" {
t.Fatal("email parse get password error")
}
if mail.Host != "smtp.gmail.com" {
t.Fatal("email parse get host error")
}
if mail.Port != 587 {
t.Fatal("email parse get port error")
}
mail.To = []string{"xiemengjun@gmail.com"}
mail.From = "astaxie@gmail.com"
mail.Subject = "hi, just from beego!"
mail.Text = "Text Body is, of course, supported!"
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
mail.Send()
}

20
utils/rand.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"crypto/rand"
)
// RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
if len(alphabets) == 0 {
bytes[i] = alphanum[b%byte(len(alphanum))]
} else {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return bytes
}

View File

@ -9,6 +9,7 @@ type BeeMap struct {
bm map[interface{}]interface{} bm map[interface{}]interface{}
} }
// NewBeeMap return new safemap
func NewBeeMap() *BeeMap { func NewBeeMap() *BeeMap {
return &BeeMap{ return &BeeMap{
lock: new(sync.RWMutex), lock: new(sync.RWMutex),
@ -16,7 +17,7 @@ func NewBeeMap() *BeeMap {
} }
} }
//Get from maps return the k's value // Get from maps return the k's value
func (m *BeeMap) Get(k interface{}) interface{} { func (m *BeeMap) Get(k interface{}) interface{} {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@ -51,12 +52,14 @@ func (m *BeeMap) Check(k interface{}) bool {
return true return true
} }
// Delete the given key and value.
func (m *BeeMap) Delete(k interface{}) { func (m *BeeMap) Delete(k interface{}) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock() defer m.lock.Unlock()
delete(m.bm, k) delete(m.bm, k)
} }
// Items returns all items in safemap.
func (m *BeeMap) Items() map[interface{}]interface{} { func (m *BeeMap) Items() map[interface{}]interface{} {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()

View File

@ -8,6 +8,7 @@ import (
type reducetype func(interface{}) interface{} type reducetype func(interface{}) interface{}
type filtertype func(interface{}) bool type filtertype func(interface{}) bool
// InSlice checks given string in string slice or not.
func InSlice(v string, sl []string) bool { func InSlice(v string, sl []string) bool {
for _, vv := range sl { for _, vv := range sl {
if vv == v { if vv == v {
@ -17,6 +18,7 @@ func InSlice(v string, sl []string) bool {
return false return false
} }
// InSliceIface checks given interface in interface slice.
func InSliceIface(v interface{}, sl []interface{}) bool { func InSliceIface(v interface{}, sl []interface{}) bool {
for _, vv := range sl { for _, vv := range sl {
if vv == v { if vv == v {
@ -26,6 +28,7 @@ func InSliceIface(v interface{}, sl []interface{}) bool {
return false return false
} }
// SliceRandList generate an int slice from min to max.
func SliceRandList(min, max int) []int { func SliceRandList(min, max int) []int {
if max < min { if max < min {
min, max = max, min min, max = max, min
@ -40,11 +43,13 @@ func SliceRandList(min, max int) []int {
return list return list
} }
// SliceMerge merges interface slices to one slice.
func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) {
c = append(slice1, slice2...) c = append(slice1, slice2...)
return return
} }
// SliceReduce generates a new slice after parsing every value by reduce function
func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
for _, v := range slice { for _, v := range slice {
dslice = append(dslice, a(v)) dslice = append(dslice, a(v))
@ -52,12 +57,14 @@ func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
return return
} }
// SliceRand returns random one from slice.
func SliceRand(a []interface{}) (b interface{}) { func SliceRand(a []interface{}) (b interface{}) {
randnum := rand.Intn(len(a)) randnum := rand.Intn(len(a))
b = a[randnum] b = a[randnum]
return return
} }
// SliceSum sums all values in int64 slice.
func SliceSum(intslice []int64) (sum int64) { func SliceSum(intslice []int64) (sum int64) {
for _, v := range intslice { for _, v := range intslice {
sum += v sum += v
@ -65,6 +72,7 @@ func SliceSum(intslice []int64) (sum int64) {
return return
} }
// SliceFilter generates a new slice after filter function.
func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
for _, v := range slice { for _, v := range slice {
if a(v) { if a(v) {
@ -74,6 +82,7 @@ func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
return return
} }
// SliceDiff returns diff slice of slice1 - slice2.
func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
for _, v := range slice1 { for _, v := range slice1 {
if !InSliceIface(v, slice2) { if !InSliceIface(v, slice2) {
@ -83,6 +92,7 @@ func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
return return
} }
// SliceIntersect returns diff slice of slice2 - slice1.
func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
for _, v := range slice1 { for _, v := range slice1 {
if !InSliceIface(v, slice2) { if !InSliceIface(v, slice2) {
@ -92,6 +102,7 @@ func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
return return
} }
// SliceChuck separates one slice to some sized slice.
func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
if size >= len(slice) { if size >= len(slice) {
chunkslice = append(chunkslice, slice) chunkslice = append(chunkslice, slice)
@ -105,6 +116,7 @@ func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
return return
} }
// SliceRange generates a new slice from begin to end with step duration of int64 number.
func SliceRange(start, end, step int64) (intslice []int64) { func SliceRange(start, end, step int64) (intslice []int64) {
for i := start; i <= end; i += step { for i := start; i <= end; i += step {
intslice = append(intslice, i) intslice = append(intslice, i)
@ -112,6 +124,7 @@ func SliceRange(start, end, step int64) (intslice []int64) {
return return
} }
// SlicePad prepends size number of val into slice.
func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
if size <= len(slice) { if size <= len(slice) {
return slice return slice
@ -122,6 +135,7 @@ func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
return slice return slice
} }
// SliceUnique cleans repeated values in slice.
func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
for _, v := range slice { for _, v := range slice {
if !InSliceIface(v, uniqueslice) { if !InSliceIface(v, uniqueslice) {
@ -131,6 +145,7 @@ func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
return return
} }
// SliceShuffle shuffles a slice.
func SliceShuffle(slice []interface{}) []interface{} { func SliceShuffle(slice []interface{}) []interface{} {
for i := 0; i < len(slice); i++ { for i := 0; i < len(slice); i++ {
a := rand.Intn(len(slice)) a := rand.Intn(len(slice))

View File

@ -41,13 +41,16 @@ func init() {
} }
} }
// Valid function type
type ValidFunc struct { type ValidFunc struct {
Name string Name string
Params []interface{} Params []interface{}
} }
// Validate function map
type Funcs map[string]reflect.Value type Funcs map[string]reflect.Value
// validate values with named type string
func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {

View File

@ -32,6 +32,7 @@ type ValidationResult struct {
Ok bool Ok bool
} }
// Get ValidationResult by given key string.
func (r *ValidationResult) Key(key string) *ValidationResult { func (r *ValidationResult) Key(key string) *ValidationResult {
if r.Error != nil { if r.Error != nil {
r.Error.Key = key r.Error.Key = key
@ -39,6 +40,7 @@ func (r *ValidationResult) Key(key string) *ValidationResult {
return r return r
} }
// Set ValidationResult message by string or format string with args
func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult { func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult {
if r.Error != nil { if r.Error != nil {
if len(args) == 0 { if len(args) == 0 {
@ -56,10 +58,12 @@ type Validation struct {
ErrorsMap map[string]*ValidationError ErrorsMap map[string]*ValidationError
} }
// Clean all ValidationError.
func (v *Validation) Clear() { func (v *Validation) Clear() {
v.Errors = []*ValidationError{} v.Errors = []*ValidationError{}
} }
// Has ValidationError nor not.
func (v *Validation) HasErrors() bool { func (v *Validation) HasErrors() bool {
return len(v.Errors) > 0 return len(v.Errors) > 0
} }
@ -101,67 +105,83 @@ func (v *Validation) Range(obj interface{}, min, max int, key string) *Validatio
return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj)
} }
// Test that the obj is longer than min size if type is string or slice
func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult { func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult {
return v.apply(MinSize{min, key}, obj) return v.apply(MinSize{min, key}, obj)
} }
// Test that the obj is shorter than max size if type is string or slice
func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult { func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult {
return v.apply(MaxSize{max, key}, obj) return v.apply(MaxSize{max, key}, obj)
} }
// Test that the obj is same length to n if type is string or slice
func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult { func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult {
return v.apply(Length{n, key}, obj) return v.apply(Length{n, key}, obj)
} }
// Test that the obj is [a-zA-Z] if type is string
func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult { func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult {
return v.apply(Alpha{key}, obj) return v.apply(Alpha{key}, obj)
} }
// Test that the obj is [0-9] if type is string
func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult { func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult {
return v.apply(Numeric{key}, obj) return v.apply(Numeric{key}, obj)
} }
// Test that the obj is [0-9a-zA-Z] if type is string
func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult { func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult {
return v.apply(AlphaNumeric{key}, obj) return v.apply(AlphaNumeric{key}, obj)
} }
// Test that the obj matches regexp if type is string
func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult { func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
return v.apply(Match{regex, key}, obj) return v.apply(Match{regex, key}, obj)
} }
// Test that the obj doesn't match regexp if type is string
func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult { func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) return v.apply(NoMatch{Match{Regexp: regex}, key}, obj)
} }
// Test that the obj is [0-9a-zA-Z_-] if type is string
func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult { func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult {
return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj)
} }
// Test that the obj is email address if type is string
func (v *Validation) Email(obj interface{}, key string) *ValidationResult { func (v *Validation) Email(obj interface{}, key string) *ValidationResult {
return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) return v.apply(Email{Match{Regexp: emailPattern}, key}, obj)
} }
// Test that the obj is IP address if type is string
func (v *Validation) IP(obj interface{}, key string) *ValidationResult { func (v *Validation) IP(obj interface{}, key string) *ValidationResult {
return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) return v.apply(IP{Match{Regexp: ipPattern}, key}, obj)
} }
// Test that the obj is base64 encoded if type is string
func (v *Validation) Base64(obj interface{}, key string) *ValidationResult { func (v *Validation) Base64(obj interface{}, key string) *ValidationResult {
return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj)
} }
// Test that the obj is chinese mobile number if type is string
func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult { func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult {
return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj)
} }
// Test that the obj is chinese telephone number if type is string
func (v *Validation) Tel(obj interface{}, key string) *ValidationResult { func (v *Validation) Tel(obj interface{}, key string) *ValidationResult {
return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) return v.apply(Tel{Match{Regexp: telPattern}, key}, obj)
} }
// Test that the obj is chinese mobile or telephone number if type is string
func (v *Validation) Phone(obj interface{}, key string) *ValidationResult { func (v *Validation) Phone(obj interface{}, key string) *ValidationResult {
return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}},
Tel{Match: Match{Regexp: telPattern}}, key}, obj) Tel{Match: Match{Regexp: telPattern}}, key}, obj)
} }
// Test that the obj is chinese zip code if type is string
func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult { func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult {
return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj)
} }
@ -210,6 +230,7 @@ func (v *Validation) setError(err *ValidationError) {
} }
} }
// Set error message for one field in ValidationError
func (v *Validation) SetError(fieldName string, errMsg string) *ValidationError { func (v *Validation) SetError(fieldName string, errMsg string) *ValidationError {
err := &ValidationError{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} err := &ValidationError{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg}
v.setError(err) v.setError(err)
@ -230,6 +251,7 @@ func (v *Validation) Check(obj interface{}, checks ...Validator) *ValidationResu
return result return result
} }
// Validate a struct.
// the obj parameter must be a struct or a struct pointer // the obj parameter must be a struct or a struct pointer
func (v *Validation) Valid(obj interface{}) (b bool, err error) { func (v *Validation) Valid(obj interface{}) (b bool, err error) {
objT := reflect.TypeOf(obj) objT := reflect.TypeOf(obj)