1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-26 05:11:31 +00:00

Merge pull request #1 from fuxiaohei/develop

merge develop
This commit is contained in:
傅小黑 2014-01-17 07:48:39 -08:00
commit 6b5108ef92
77 changed files with 3478 additions and 685 deletions

View File

@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
beego is licensed under the Apache Licence, Version 2.0 beego is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html). (http://www.apache.org/licenses/LICENSE-2.0.html).
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
## Use case [koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)

8
app.go
View File

@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
return app return app
} }
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
app.Handlers.AddAutoPrefix(prefix, c)
return app
}
// UrlFor creates a url with another registered controller handler with params. // UrlFor creates a url with another registered controller handler with params.
// The endpoint is formed as path.controller.name to defined the controller method which will run. // The endpoint is formed as path.controller.name to defined the controller method which will run.
// The values need key-pair data to assign into controller method. // The values need key-pair data to assign into controller method.

126
beego.go
View File

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/astaxie/beego/middleware" "github.com/astaxie/beego/middleware"
@ -13,6 +14,76 @@ import (
// beego web framework version. // beego web framework version.
const VERSION = "1.0.1" const VERSION = "1.0.1"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
type groupRouter struct {
pattern string
controller ControllerInterface
mappingMethods string
}
// RouterGroups which will store routers
type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
return make([]groupRouter, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
pattern,
c,
mappingMethod[0],
}
} else {
newRG = groupRouter{
pattern,
c,
"",
}
}
gr = append(gr, newRG)
}
func (gr GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
gr = append(gr, newRG)
}
// AddGroupRouter with the prefix
// it will register the router in BeeApp
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
func AddGroupRouter(prefix string, groups GroupRouters) *App {
for _, v := range groups {
if v.pattern == "" {
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
} else if v.mappingMethods != "" {
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
} else {
BeeApp.Router(prefix+v.pattern, v.controller)
}
}
return BeeApp
}
// Router adds a patterned controller handler to BeeApp. // Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router. // it's an alias method of App.Router.
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
return BeeApp return BeeApp
} }
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.AutoRouterWithPrefix(prefix, c)
return BeeApp
}
// ErrorHandler registers http.HandlerFunc to each http err code string. // ErrorHandler registers http.HandlerFunc to each http err code string.
// usage: // usage:
// beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("404",NotFound)
@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
return BeeApp return BeeApp
} }
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
}
// Run beego application. // Run beego application.
// it's alias of App.Run. // it's alias of App.Run.
func Run() { func Run() {
@ -99,18 +183,32 @@ func Run() {
} }
} }
//init mime // do hooks function
initMime() for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn { if SessionOn {
GlobalSessions, _ = session.NewManager(SessionProvider, var err error
SessionName, sessionConfig := AppConfig.String("sessionConfig")
SessionGCMaxLifetime, if sessionConfig == "" {
SessionSavePath, sessionConfig = `{"cookieName":"` + SessionName + `",` +
HttpTLS, `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
SessionHashFunc, `"providerConfig":"` + SessionSavePath + `",` +
SessionHashKey, `"secure":` + strconv.FormatBool(HttpTLS) + `,` +
SessionCookieLifeTime) `"sessionIDHashFunc":"` + SessionHashFunc + `",` +
`"sessionIDHashKey":"` + SessionHashKey + `",` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC() go GlobalSessions.GC()
} }
@ -123,7 +221,7 @@ func Run() {
middleware.VERSION = VERSION middleware.VERSION = VERSION
middleware.AppName = AppName middleware.AppName = AppName
middleware.RegisterErrorHander() middleware.RegisterErrorHandler()
if EnableAdmin { if EnableAdmin {
go BeeAdminApp.Run() go BeeAdminApp.Run()
@ -131,3 +229,9 @@ func Run() {
BeeApp.Run() BeeApp.Run()
} }
func init() {
hooks = make([]hookfunc, 0)
//init mime
AddAPPStartHook(initMime)
}

2
cache/README.md vendored
View File

@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter ## Memcache adapter
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Configure like this: Configure like this:

50
cache/cache_test.go vendored
View File

@ -5,7 +5,7 @@ import (
"time" "time"
) )
func Test_cache(t *testing.T) { func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`) bm, err := NewCache("memory", `{"interval":20}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
t.Error("delete err") t.Error("delete err")
} }
} }
func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
}

10
cache/file.go vendored
View File

@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
var cfg map[string]string var cfg map[string]string
json.Unmarshal([]byte(config), &cfg) json.Unmarshal([]byte(config), &cfg)
//fmt.Println(cfg) //fmt.Println(cfg)
//fmt.Println(config)
if _, ok := cfg["CachePath"]; !ok { if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath cfg["CachePath"] = FileCachePath
} }
@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
return "" return ""
} }
var to FileCacheItem var to FileCacheItem
Gob_decode([]byte(filedata), &to) Gob_decode(filedata, &to)
if to.Expired < time.Now().Unix() { if to.Expired < time.Now().Unix() {
return "" return ""
} }
@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
func (this *FileCache) Incr(key string) error { func (this *FileCache) Incr(key string) error {
data := this.Get(key) data := this.Get(key)
var incr int var incr int
fmt.Println(reflect.TypeOf(data).Name()) //fmt.Println(reflect.TypeOf(data).Name())
if reflect.TypeOf(data).Name() != "int" { if reflect.TypeOf(data).Name() != "int" {
incr = 0 incr = 0
} else { } else {
@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
// Clean cached files. // Clean cached files.
// not implemented. // not implemented.
func (this *FileCache) ClearAll() error { func (this *FileCache) ClearAll() error {
//this.CachePath .递归删除 //this.CachePath
return nil return nil
} }
@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
} }
// Gob decodes file cache item. // Gob decodes file cache item.
func Gob_decode(data []byte, to interface{}) error { func Gob_decode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf) dec := gob.NewDecoder(buf)
return dec.Decode(&to) return dec.Decode(&to)

41
cache/memcache.go vendored
View File

@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
// get value from memcache. // get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} { func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
// put value to memcache. only support string. // put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, ok := val.(string) v, ok := val.(string)
if !ok { if !ok {
@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// delete value in memcache. // delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error { func (rc *MemcacheCache) Delete(key string) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
_, err := rc.c.Delete(key) _, err := rc.c.Delete(key)
return err return err
@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
// check value exists in memcache. // check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool { func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
// clear all cached in memcache. // clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error { func (rc *MemcacheCache) ClearAll() error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
err := rc.c.FlushAll() err := rc.c.FlushAll()
return err return err
@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
rc.c = rc.connectInit() var err error
if rc.c == nil { rc.c, err = rc.connectInit()
if err != nil {
return errors.New("dial tcp conn error") return errors.New("dial tcp conn error")
} }
return nil return nil
} }
// connect to memcache and keep the connection. // connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() *memcache.Connection { func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rc.conninfo) c, err := memcache.Connect(rc.conninfo)
if err != nil { if err != nil {
return nil return nil, err
} }
return c return c, nil
} }
func init() { func init() {

106
cache/redis.go vendored
View File

@ -3,6 +3,7 @@ package cache
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"time"
"github.com/beego/redigo/redis" "github.com/beego/redigo/redis"
) )
@ -14,7 +15,7 @@ var (
// Redis cache adapter. // Redis cache adapter.
type RedisCache struct { type RedisCache struct {
c redis.Conn p *redis.Pool // redis connection pool
conninfo string conninfo string
key string key string
} }
@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey} return &RedisCache{key: DefaultKey}
} }
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// Get cache from redis. // Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} { func (rc *RedisCache) Get(key string) interface{} {
if rc.c == nil { v, err := rc.do("HGET", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return nil
}
}
v, err := rc.c.Do("HGET", rc.key, key)
if err != nil { if err != nil {
return nil return nil
} }
return v return v
} }
// put cache to redis. // put cache to redis.
// timeout is ignored. // timeout is ignored.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { _, err := rc.do("HSET", rc.key, key, val)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HSET", rc.key, key, val)
return err return err
} }
// delete cache in redis. // delete cache in redis.
func (rc *RedisCache) Delete(key string) error { func (rc *RedisCache) Delete(key string) error {
if rc.c == nil { _, err := rc.do("HDEL", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HDEL", rc.key, key)
return err return err
} }
// check cache exist in redis. // check cache exist in redis.
func (rc *RedisCache) IsExist(key string) bool { func (rc *RedisCache) IsExist(key string) bool {
if rc.c == nil { v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
if err != nil { if err != nil {
return false return false
} }
return v return v
} }
// increase counter in redis. // increase counter in redis.
func (rc *RedisCache) Incr(key string) error { func (rc *RedisCache) Incr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err return err
} }
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
}
// decrease counter in redis. // decrease counter in redis.
func (rc *RedisCache) Decr(key string) error { func (rc *RedisCache) Decr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err return err
} }
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
}
// clean all cache in redis. delete this redis collection. // clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error { func (rc *RedisCache) ClearAll() error {
if rc.c == nil { _, err := rc.do("DEL", rc.key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("DEL", rc.key)
return err return err
} }
@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
func (rc *RedisCache) StartAndGC(config string) error { func (rc *RedisCache) StartAndGC(config string) error {
var cf map[string]string var cf map[string]string
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok { if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey cf["key"] = DefaultKey
} }
if _, ok := cf["conn"]; !ok { if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.key = cf["key"] rc.key = cf["key"]
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
var err error rc.connectInit()
rc.c, err = rc.connectInit()
if err != nil { c := rc.p.Get()
defer c.Close()
if err := c.Err(); err != nil {
return err return err
} }
if rc.c == nil {
return errors.New("dial tcp conn error")
}
return nil return nil
} }
// connect to redis. // connect to redis.
func (rc *RedisCache) connectInit() (redis.Conn, error) { func (rc *RedisCache) connectInit() {
// initialize a new pool
rc.p = &redis.Pool{
MaxIdle: 3,
IdleTimeout: 180 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo) c, err := redis.Dial("tcp", rc.conninfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
},
}
} }
func init() { func init() {

View File

@ -40,6 +40,7 @@ var (
SessionHashFunc string // session hash generation func. SessionHashFunc string // session hash generation func.
SessionHashKey string // session hash salt string. SessionHashKey string // session hash salt string.
SessionCookieLifeTime int // the life time of session id in cookie. SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
UseFcgi bool UseFcgi bool
MaxMemory int64 MaxMemory int64
EnableGzip bool // flag of enable gzip EnableGzip bool // flag of enable gzip
@ -96,6 +97,7 @@ func init() {
SessionHashFunc = "sha1" SessionHashFunc = "sha1"
SessionHashKey = "beegoserversessionkey" SessionHashKey = "beegoserversessionkey"
SessionCookieLifeTime = 0 //set cookie default is the brower life SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false UseFcgi = false
@ -139,6 +141,7 @@ func init() {
func ParseConfig() (err error) { func ParseConfig() (err error) {
AppConfig, err = config.NewConfig("ini", AppConfigPath) AppConfig, err = config.NewConfig("ini", AppConfigPath)
if err != nil { if err != nil {
AppConfig = config.NewFakeConfig()
return err return err
} else { } else {
HttpAddr = AppConfig.String("HttpAddr") HttpAddr = AppConfig.String("HttpAddr")

View File

@ -8,6 +8,7 @@ import (
type ConfigContainer interface { type ConfigContainer interface {
Set(key, val string) error // support section::key type in given key when using ini type. Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error) Int(key string) (int, error)
Int64(key string) (int64, error) Int64(key string) (int64, error)
Bool(key string) (bool, error) Bool(key string) (bool, error)

62
config/fake.go Normal file
View File

@ -0,0 +1,62 @@
package config
import (
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
key = strings.ToLower(key)
return c.data[key]
}
func (c *fakeConfigContainer) Set(key, val string) error {
key = strings.ToLower(key)
c.data[key] = val
return nil
}
func (c *fakeConfigContainer) String(key string) string {
return c.getData(key)
}
func (c *fakeConfigContainer) Strings(key string) []string {
return strings.Split(c.getData(key), ";")
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
key = strings.ToLower(key)
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
var _ ConfigContainer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
return &fakeConfigContainer{
data: make(map[string]string),
}
}

View File

@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
return c.getdata(key) return c.getdata(key)
} }
// Strings returns the []string value for a given key.
func (c *IniConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
// if write to one section, the key need be "section::key". // if write to one section, the key need be "section::key".
// if the section is not existed, it panics. // if the section is not existed, it panics.

View File

@ -19,6 +19,7 @@ copyrequestbody = true
key1="asta" key1="asta"
key2 = "xie" key2 = "xie"
CaseInsensitive = true CaseInsensitive = true
peers = one;two;three
` `
func TestIni(t *testing.T) { func TestIni(t *testing.T) {
@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true { if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
t.Fatal("get demo.caseinsensitive error") t.Fatal("get demo.caseinsensitive error")
} }
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
t.Fatal("get strings error", data)
} else if data[0] != "one" {
t.Fatal("get first params error not equat to one")
}
} }

View File

@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error { func (c *JsonConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/beego/x2j" "github.com/beego/x2j"
@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error { func (c *XMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"strings"
"sync" "sync"
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error { func (c *YAMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()

View File

@ -3,7 +3,6 @@ package beego
import ( import (
"bytes" "bytes"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
@ -22,6 +21,7 @@ import (
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
) )
var ( var (
@ -140,7 +140,7 @@ func (c *Controller) RenderString() (string, error) {
return string(b), e return string(b), e
} }
// RenderBytes returns the bytes of renderd tempate string. Do not send out response. // RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) { func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout //if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" { if c.Layout != "" {
@ -165,7 +165,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
if c.LayoutSections != nil { if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections { for sectionName, sectionTpl := range c.LayoutSections {
if (sectionTpl == "") { if sectionTpl == "" {
c.Data[sectionName] = "" c.Data[sectionName] = ""
continue continue
} }
@ -391,6 +391,7 @@ func (c *Controller) DelSession(name interface{}) {
// SessionRegenerateID regenerates session id for this session. // SessionRegenerateID regenerates session id for this session.
// the session data have no changes. // the session data have no changes.
func (c *Controller) SessionRegenerateID() { func (c *Controller) SessionRegenerateID() {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession c.Ctx.Input.CruSession = c.CruSession
} }
@ -454,7 +455,7 @@ func (c *Controller) XsrfToken() string {
} else { } else {
expire = int64(XSRFExpire) expire = int64(XSRFExpire)
} }
token = getRandomString(15) token = string(utils.RandomCreateBytes(15))
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire) c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
} }
c._xsrf_token = token c._xsrf_token = token
@ -491,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string {
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
return c.controllerName, c.actionName return c.controllerName, c.actionName
} }
// getRandomString returns random string.
func getRandomString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
return string(bytes)
}

View File

@ -1,12 +1,13 @@
package controllers package controllers
import ( import (
"github.com/astaxie/beego"
"github.com/garyburd/go-websocket/websocket"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"time" "time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
) )
const ( const (

View File

@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if router == mr.pattern { if router == mr.pattern {
return true, nil return true, nil
} }
//pattern /admin router /admin/ match
//pattern /admin/ router /admin don't match, because url will 301 in router
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
return true, nil
}
if mr.hasregex { if mr.hasregex {
if !mr.regex.MatchString(router) { if !mr.regex.MatchString(router) {
return false, nil return false, nil
@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
return false, nil return false, nil
} }
func buildFilter(pattern string, filter FilterFunc) *FilterRouter { func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
mr := new(FilterRouter) mr := new(FilterRouter)
mr.params = make(map[int]string) mr.params = make(map[int]string)
mr.filterFunc = filter mr.filterFunc = filter
@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j := 0 j := 0
for i, part := range parts { for i, part := range parts {
if strings.HasPrefix(part, ":") { if strings.HasPrefix(part, ":") {
expr := "(.+)" expr := "(.*)"
//a user may choose to override the default expression //a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+) // similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 { if index := strings.Index(part, "("); index != -1 {
@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j++ j++
} }
if strings.HasPrefix(part, "*") { if strings.HasPrefix(part, "*") {
expr := "(.+)" expr := "(.*)"
if part == "*.*" { if part == "*.*" {
mr.params[j] = ":path" mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)" parts[i] = "([^.]+).([^.]+)"
@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
pattern = strings.Join(parts, "/") pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern) regex, regexErr := regexp.Compile(pattern)
if regexErr != nil { if regexErr != nil {
//TODO add error handling here to avoid panic return nil, regexErr
panic(regexErr)
} }
mr.regex = regex mr.regex = regex
mr.hasregex = true mr.hasregex = true
} }
mr.pattern = pattern mr.pattern = pattern
return mr return mr, nil
} }

View File

@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
t.Errorf("user define func can't run") t.Errorf("user define func can't run")
} }
} }
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}

View File

@ -7,6 +7,8 @@ import (
"net" "net"
) )
// ConnWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct { type ConnWriter struct {
lg *log.Logger lg *log.Logger
innerWriter io.WriteCloser innerWriter io.WriteCloser
@ -17,12 +19,15 @@ type ConnWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface { func NewConn() LoggerInterface {
conn := new(ConnWriter) conn := new(ConnWriter)
conn.Level = LevelTrace conn.Level = LevelTrace
return conn return conn
} }
// init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error { func (c *ConnWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error { func (c *ConnWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty.
func (c *ConnWriter) Flush() { func (c *ConnWriter) Flush() {
} }
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() { func (c *ConnWriter) Destroy() {
if c.innerWriter == nil { if c.innerWriter == nil {
return return

View File

@ -6,11 +6,13 @@ import (
"os" "os"
) )
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct { type ConsoleWriter struct {
lg *log.Logger lg *log.Logger
Level int `json:"level"` Level int `json:"level"`
} }
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface { func NewConsole() LoggerInterface {
cw := new(ConsoleWriter) cw := new(ConsoleWriter)
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime) cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
@ -18,6 +20,8 @@ func NewConsole() LoggerInterface {
return cw return cw
} }
// init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error { func (c *ConsoleWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
@ -26,6 +30,7 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error { func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
@ -34,10 +39,12 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty.
func (c *ConsoleWriter) Destroy() { func (c *ConsoleWriter) Destroy() {
} }
// implementing method. empty.
func (c *ConsoleWriter) Flush() { func (c *ConsoleWriter) Flush() {
} }

View File

@ -13,6 +13,8 @@ import (
"time" "time"
) )
// FileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct { type FileLogWriter struct {
*log.Logger *log.Logger
mw *MuxWriter mw *MuxWriter
@ -38,17 +40,20 @@ type FileLogWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// an *os.File writer with locker.
type MuxWriter struct { type MuxWriter struct {
sync.Mutex sync.Mutex
fd *os.File fd *os.File
} }
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) { func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
return l.fd.Write(b) return l.fd.Write(b)
} }
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) { func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil { if l.fd != nil {
l.fd.Close() l.fd.Close()
@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
l.fd = fd l.fd = fd
} }
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface { func NewFileWriter() LoggerInterface {
w := &FileLogWriter{ w := &FileLogWriter{
Filename: "", Filename: "",
@ -73,7 +79,8 @@ func NewFileWriter() LoggerInterface {
return w return w
} }
// jsonconfig like this // Init file logger with json config.
// jsonconfig like:
// { // {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxlines":10000, // "maxlines":10000,
@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
return err return err
} }
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error { func (w *FileLogWriter) StartLogger() error {
fd, err := w.createLogFile() fd, err := w.createLogFile()
if err != nil { if err != nil {
@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
w.maxsize_cursize += size w.maxsize_cursize += size
} }
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error { func (w *FileLogWriter) WriteMsg(msg string, level int) error {
if level < w.Level { if level < w.Level {
return nil return nil
@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
return nil return nil
} }
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error { func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.Filename) _, err := os.Lstat(w.Filename)
if err == nil { // file exists if err == nil { // file exists
@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
}) })
} }
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() { func (w *FileLogWriter) Destroy() {
w.mw.fd.Close() w.mw.fd.Close()
} }
// flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() { func (w *FileLogWriter) Flush() {
w.mw.fd.Sync() w.mw.fd.Sync()
} }

View File

@ -6,6 +6,7 @@ import (
) )
const ( const (
// log message levels
LevelTrace = iota LevelTrace = iota
LevelDebug LevelDebug
LevelInfo LevelInfo
@ -16,6 +17,7 @@ const (
type loggerType func() LoggerInterface type loggerType func() LoggerInterface
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface { type LoggerInterface interface {
Init(config string) error Init(config string) error
WriteMsg(msg string, level int) error WriteMsg(msg string, level int) error
@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
adapters[name] = log adapters[name] = log
} }
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
@ -50,7 +54,9 @@ type logMsg struct {
msg string msg string
} }
// config need to be correct JSON as string: {"interval":360} // NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger { func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.msg = make(chan *logMsg, channellen) bl.msg = make(chan *logMsg, channellen)
@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
return bl return bl
} }
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error { func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
} }
} }
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error { func (bl *BeeLogger) DelLogger(adaptername string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
return nil return nil
} }
// set log message level.
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
func (bl *BeeLogger) SetLevel(l int) { func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) StartLogger() { func (bl *BeeLogger) StartLogger() {
for { for {
select { select {
@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
} }
} }
// log trace level message.
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
msg := fmt.Sprintf("[T] "+format, v...) msg := fmt.Sprintf("[T] "+format, v...)
bl.writerMsg(LevelTrace, msg) bl.writerMsg(LevelTrace, msg)
} }
// log debug level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) { func (bl *BeeLogger) Debug(format string, v ...interface{}) {
msg := fmt.Sprintf("[D] "+format, v...) msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg) bl.writerMsg(LevelDebug, msg)
} }
// log info level message.
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
msg := fmt.Sprintf("[I] "+format, v...) msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInfo, msg) bl.writerMsg(LevelInfo, msg)
} }
// log warn level message.
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
msg := fmt.Sprintf("[W] "+format, v...) msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarn, msg) bl.writerMsg(LevelWarn, msg)
} }
// log error level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) { func (bl *BeeLogger) Error(format string, v ...interface{}) {
msg := fmt.Sprintf("[E] "+format, v...) msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg) bl.writerMsg(LevelError, msg)
} }
// log critical level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) { func (bl *BeeLogger) Critical(format string, v ...interface{}) {
msg := fmt.Sprintf("[C] "+format, v...) msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg) bl.writerMsg(LevelCritical, msg)
} }
//flush all chan data // flush all chan data.
func (bl *BeeLogger) Flush() { func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() { func (bl *BeeLogger) Close() {
for { for {
if len(bl.msg) > 0 { if len(bl.msg) > 0 {

View File

@ -12,7 +12,7 @@ const (
subjectPhrase = "Diagnostic message from server" subjectPhrase = "Diagnostic message from server"
) )
// smtpWriter is used to send emails via given SMTP-server. // smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct { type SmtpWriter struct {
Username string `json:"Username"` Username string `json:"Username"`
Password string `json:"password"` Password string `json:"password"`
@ -22,10 +22,21 @@ type SmtpWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create smtp writer.
func NewSmtpWriter() LoggerInterface { func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace} return &SmtpWriter{Level: LevelTrace}
} }
// init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error { func (s *SmtpWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil { if err != nil {
@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error { func (s *SmtpWriter) WriteMsg(msg string, level int) error {
if level < s.Level { if level < s.Level {
return nil return nil
@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
return err return err
} }
// implementing method. empty.
func (s *SmtpWriter) Flush() { func (s *SmtpWriter) Flush() {
return return
} }
// implementing method. empty.
func (s *SmtpWriter) Destroy() { func (s *SmtpWriter) Destroy() {
return return
} }

View File

@ -5,16 +5,17 @@ import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"errors" "errors"
//"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
) )
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo) var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file. // OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable. // it's used for serve static file if gzip enable.
@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
modtime := osfileinfo.ModTime() modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size() fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path] cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize { if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
//fmt.Printf("read %s file %s from cache\n", zip, path)
} else { } else {
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
var content []byte var content []byte
if zip == "gzip" { if zip == "gzip" {
//将文件内容压缩到zipbuf中 //将文件内容压缩到zipbuf中
@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
} }
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize} cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi gmfim[zip+":"+path] = cfi
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
} }
return &MemFile{fi: cfi, offset: 0}, nil return &MemFile{fi: cfi, offset: 0}, nil
} }

View File

@ -61,6 +61,7 @@ var tpl = `
</html> </html>
` `
// render default application error page with error and stack string.
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) { func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl) t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string) data := make(map[string]string)
@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
data["Stack"] = Stack data["Stack"] = Stack
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version() data["GoVersion"] = runtime.Version()
rw.WriteHeader(500)
t.Execute(rw, data) t.Execute(rw, data)
} }
@ -174,18 +176,19 @@ var errtpl = `
</html> </html>
` `
// map of http handlers for each error string.
var ErrorMaps map[string]http.HandlerFunc var ErrorMaps map[string]http.HandlerFunc
func init() { func init() {
ErrorMaps = make(map[string]http.HandlerFunc) ErrorMaps = make(map[string]http.HandlerFunc)
} }
//404 // show 404 notfound error.
func NotFound(rw http.ResponseWriter, r *http.Request) { func NotFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Page Not Found" data["Title"] = "Page Not Found"
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." + data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The page has moved" + "<br>The page has moved" +
@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//401 // show 401 unauthorized error.
func Unauthorized(rw http.ResponseWriter, r *http.Request) { func Unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Unauthorized" data["Title"] = "Unauthorized"
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." + data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Check the credentials that you supplied" + "<br>The credentials you supplied are incorrect" +
"<br>Check the address for errors" + "<br>There are errors in the website address" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusUnauthorized) //rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data) t.Execute(rw, data)
} }
//403 // show 403 forbidden error.
func Forbidden(rw http.ResponseWriter, r *http.Request) { func Forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Forbidden" data["Title"] = "Forbidden"
data["Content"] = template.HTML("<br>The Page You have requested forbidden." + data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Your address may be blocked" + "<br>Your address may be blocked" +
@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//503 // show 503 service unavailable error.
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Service Unavailable" data["Title"] = "Service Unavailable"
data["Content"] = template.HTML("<br>The Page You have requested unavailable." + data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br><br>The page is overloaded" + "<br><br>The page is overloaded" +
@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//500 // show 500 internal server error.
func InternalServerError(rw http.ResponseWriter, r *http.Request) { func InternalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Internal Server Error" data["Title"] = "Internal Server Error"
data["Content"] = template.HTML("<br>The Page You have requested has down now." + data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>simply try again later" + "<br>Please try again later and report the error to the website administrator" +
"<br>you should report the fault to the website administrator" + "<br></ul>")
"</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusInternalServerError) //rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 500 internal error with simple text string.
func SimpleServerError(rw http.ResponseWriter, r *http.Request) { func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} }
// add http handler for given error string.
func Errorhandler(err string, h http.HandlerFunc) { func Errorhandler(err string, h http.HandlerFunc) {
ErrorMaps[err] = h ErrorMaps[err] = h
} }
func RegisterErrorHander() { // register default error http handlers, 404,401,403,500 and 503.
func RegisterErrorHandler() {
if _, ok := ErrorMaps["404"]; !ok { if _, ok := ErrorMaps["404"]; !ok {
ErrorMaps["404"] = NotFound ErrorMaps["404"] = NotFound
} }
@ -292,6 +297,8 @@ func RegisterErrorHander() {
} }
} }
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) { func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
if h, ok := ErrorMaps[errcode]; ok { if h, ok := ErrorMaps[errcode]; ok {
isint, err := strconv.Atoi(errcode) isint, err := strconv.Atoi(errcode)

View File

@ -2,16 +2,19 @@ package middleware
import "fmt" import "fmt"
// http exceptions
type HTTPException struct { type HTTPException struct {
StatusCode int // http status code 4xx, 5xx StatusCode int // http status code 4xx, 5xx
Description string Description string
} }
// return http exception error string, e.g. "400 Bad Request".
func (e *HTTPException) Error() string { func (e *HTTPException) Error() string {
// return `status description`, e.g. `400 Bad Request`
return fmt.Sprintf("%d %s", e.StatusCode, e.Description) return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
} }
// map of http exceptions for each http status code int.
// defined 400,401,403,404,405,500,502,503 and 504 default.
var HTTPExceptionMaps map[int]HTTPException var HTTPExceptionMaps map[int]HTTPException
func init() { func init() {

View File

@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
".mustache": "text/html", ".mustache": "text/html",
} }
func initMime() { func initMime() error {
for k, v := range mimemaps { for k, v := range mimemaps {
mime.AddExtensionType(k, v) mime.AddExtensionType(k, v)
} }
return nil
} }

View File

@ -16,6 +16,7 @@ var (
commands = make(map[string]commander) commands = make(map[string]commander)
) )
// print help.
func printHelp(errs ...string) { func printHelp(errs ...string) {
content := `orm command usage: content := `orm command usage:
@ -31,6 +32,7 @@ func printHelp(errs ...string) {
os.Exit(2) os.Exit(2)
} }
// listen for orm command and then run it if command arguments passed.
func RunCommand() { func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" { if len(os.Args) < 2 || os.Args[1] != "orm" {
return return
@ -58,6 +60,7 @@ func RunCommand() {
} }
} }
// sync database struct command interface.
type commandSyncDb struct { type commandSyncDb struct {
al *alias al *alias
force bool force bool
@ -66,6 +69,7 @@ type commandSyncDb struct {
rtOnError bool rtOnError bool
} }
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) { func (d *commandSyncDb) Parse(args []string) {
var name string var name string
@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSyncDb) Run() error { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
return nil return nil
} }
// database creation commander interface implement.
type commandSqlAll struct { type commandSqlAll struct {
al *alias al *alias
} }
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) { func (d *commandSqlAll) Parse(args []string) {
var name string var name string
@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSqlAll) Run() error { func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSql(d.al)
var all []string var all []string
@ -243,6 +251,10 @@ func init() {
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSqlAll)
} }
// run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error { func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap() BootStrap()

View File

@ -12,6 +12,7 @@ type dbIndex struct {
Sql string Sql string
} }
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) { func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls return sqls
} }
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
fieldType := fi.fieldType fieldType := fi.fieldType
@ -79,6 +81,7 @@ checkColumn:
return return
} }
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string { func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi) typ := getColumnTyp(al, fi)
@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
} }
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")

171
orm/db.go
View File

@ -15,7 +15,7 @@ const (
) )
var ( var (
ErrMissPK = errors.New("missed pk value") ErrMissPK = errors.New("missed pk value") // missing pk error
) )
var ( var (
@ -45,13 +45,22 @@ var (
} }
) )
// an instance of dbBaser interface/
type dbBase struct { type dbBase struct {
ins dbBaser ins dbBaser
} }
// check dbBase implements dbBaser interface.
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
var columns []string
if names != nil {
columns = *names
}
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil { if fi, _ = mi.fields.GetByAny(column); fi != nil {
@ -64,14 +73,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if names != nil {
columns = append(columns, column) columns = append(columns, column)
}
values = append(values, value) values = append(values, value)
} }
if names != nil {
*names = columns
}
return return
} }
// get one field value in struct column as interface.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{} var value interface{}
if fi.pk { if fi.pk {
@ -140,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
return value, nil return value, nil
} }
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
@ -165,8 +185,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err return stmt, query, err
} }
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -185,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
} }
} }
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
@ -192,7 +214,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +225,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
} }
whereCols = append(whereCols, pkColumn) whereCols = []string{pkColumn}
args = append(args, pkValue) args = append(args, pkValue)
} }
@ -243,16 +266,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return nil return nil
} }
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) names := make([]string, 0, len(mi.fields.dbcols)-1)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, names, values) return d.InsertValue(q, mi, false, names, values)
} }
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { // multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
values []interface{}
names []string
)
// typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len()
for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
} else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return cnt, err
}
if len(vus) != len(names) {
return cnt, ErrArgs
}
nums += copy(values[nums:], vus)
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
cnt += num
nums = 0
}
}
return cnt, nil
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -264,36 +348,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
qmarks := strings.Join(marks, ", ") qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep) columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := q.Exec(query, values...); err == nil { if res, err := q.Exec(query, values...); err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId() return res.LastInsertId()
} else { } else {
return 0, err return 0, err
} }
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
} }
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
var setNames []string
// if specify cols length is zero, then commit all columns. // if specify cols length is zero, then commit all columns.
if len(cols) == 0 { if len(cols) == 0 {
cols = mi.fields.dbcols cols = mi.fields.dbcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
} else {
setNames = make([]string, 0, len(cols))
} }
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -317,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil return 0, nil
} }
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
@ -358,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil return 0, nil
} }
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
@ -433,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
@ -459,8 +564,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
return nil return nil
} }
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
@ -486,6 +594,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r rs = r
} }
defer rs.Close()
var ref interface{} var ref interface{}
args = make([]interface{}, 0) args = make([]interface{}, 0)
@ -532,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil return 0, nil
} }
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
@ -640,6 +751,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
slice := ind slice := ind
var cnt int64 var cnt int64
@ -739,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil return cnt, nil
} }
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -759,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return return
} }
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -812,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params return sql, params
} }
// gernerate sql string with inner function, such as UPPER(text).
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use // default not use
} }
// set values to struct column.
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols { for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -837,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
} }
} }
// convert value from database result to value following in field type.
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil { if val == nil {
return nil, nil return nil, nil
@ -989,6 +1107,7 @@ end:
} }
// set one value to struct column field.
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.fieldType
@ -1063,6 +1182,7 @@ setValue:
return value, nil return value, nil
} }
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
@ -1150,6 +1270,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
var ( var (
cnt int64 cnt int64
columns []string columns []string
@ -1228,6 +1350,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil return cnt, nil
} }
// flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool { func (d *dbBase) SupportUpdateJoin() bool {
return true return true
} }
@ -1236,30 +1359,37 @@ func (d *dbBase) MaxLimit() uint64 {
return 18446744073709551615 return 18446744073709551615
} }
// return quote.
func (d *dbBase) TableQuote() string { func (d *dbBase) TableQuote() string {
return "`" return "`"
} }
// replace value placeholer in parametered sql string.
func (d *dbBase) ReplaceMarks(query *string) { func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing // default use `?` as mark, do nothing
} }
// flag of RETURNING sql.
func (d *dbBase) HasReturningID(*modelInfo, *string) bool { func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// convert time to db.
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// get database types.
func (d *dbBase) DbTypes() map[string]string { func (d *dbBase) DbTypes() map[string]string {
return nil return nil
} }
// gt all tables.
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool) tables := make(map[string]bool)
query := d.ins.ShowTablesQuery() query := d.ins.ShowTablesQuery()
@ -1268,6 +1398,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, err return tables, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var table string var table string
err := rows.Scan(&table) err := rows.Scan(&table)
@ -1282,6 +1414,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, nil return tables, nil
} }
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string) columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
@ -1290,6 +1423,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, err return columns, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var ( var (
name string name string
@ -1306,18 +1441,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, nil return columns, nil
} }
// not implement.
func (d *dbBase) OperatorSql(operator string) string { func (d *dbBase) OperatorSql(operator string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowTablesQuery() string { func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowColumnsQuery(table string) string { func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -9,27 +9,32 @@ import (
"time" "time"
) )
// database driver constant int.
type DriverType int type DriverType int
const ( const (
_ DriverType = iota _ DriverType = iota // int enum type
DR_MySQL DR_MySQL // mysql
DR_Sqlite DR_Sqlite // sqlite
DR_Oracle DR_Oracle // oracle
DR_Postgres DR_Postgres // pgsql
) )
// database driver string.
type driver string type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType { func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d)) a, _ := dataBaseCache.get(string(d))
return a.Driver return a.Driver
} }
// get name of current driver
func (d driver) Name() string { func (d driver) Name() string {
return string(d) return string(d)
} }
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver) var _ Driver = new(driver)
var ( var (
@ -47,11 +52,13 @@ var (
} }
) )
// database alias cacher.
type _dbCache struct { type _dbCache struct {
mux sync.RWMutex mux sync.RWMutex
cache map[string]*alias cache map[string]*alias
} }
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) { func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock() ac.mux.Lock()
defer ac.mux.Unlock() defer ac.mux.Unlock()
@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
return return
} }
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) { func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock() ac.mux.RLock()
defer ac.mux.RUnlock() defer ac.mux.RUnlock()
@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
return return
} }
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) { func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default") al, _ = ac.get("default")
return return
@ -123,21 +132,18 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
switch al.Driver { switch al.Driver {
case DR_MySQL: case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone") row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
if tz == "SYSTEM" { if len(tz) >= 8 {
tz = "" if tz[0] != '-' {
row = al.DB.QueryRow("SELECT @@system_time_zone") tz = "+" + tz
row.Scan(&tz)
t, err := time.Parse("MST", tz)
if err == nil {
al.TZ = t.Location()
} }
} else { t, err := time.Parse("-07:00:00", tz)
t, err := time.Parse("-07:00", tz)
if err == nil { if err == nil {
al.TZ = t.Location() al.TZ = t.Location()
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
@ -163,6 +169,8 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
loc, err := time.LoadLocation(tz) loc, err := time.LoadLocation(tz)
if err == nil { if err == nil {
al.TZ = loc al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
) )
// mysql operators.
var mysqlOperators = map[string]string{ var mysqlOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ?", "iexact": "LIKE ?",
@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?", "iendswith": "LIKE ?",
} }
// mysql column field types.
var mysqlTypes = map[string]string{ var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// mysql dbBaser implementation.
type dbBaseMysql struct { type dbBaseMysql struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string { func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string { func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes return mysqlTypes
} }
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string { func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
} }
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string { func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table) "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
} }
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0 return cnt > 0
} }
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser { func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql) b := new(dbBaseMysql)
b.ins = b b.ins = b

View File

@ -1,11 +1,13 @@
package orm package orm
// oracle dbBaser
type dbBaseOracle struct { type dbBaseOracle struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseOracle) var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser { func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle) b := new(dbBaseOracle)
b.ins = b b.ins = b

View File

@ -5,6 +5,7 @@ import (
"strconv" "strconv"
) )
// postgresql operators.
var postgresOperators = map[string]string{ var postgresOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "= UPPER(?)", "iexact": "= UPPER(?)",
@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)", "iendswith": "LIKE UPPER(?)",
} }
// postgresql column field types.
var postgresTypes = map[string]string{ var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY", "auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// postgresql dbBaser.
type dbBasePostgres struct { type dbBasePostgres struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBasePostgres) var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string { func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator { switch operator {
case "contains", "startswith", "endswith": case "contains", "startswith", "endswith":
@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
} }
} }
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool { func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false return false
} }
@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0 return 0
} }
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string { func (d *dbBasePostgres) TableQuote() string {
return `"` return `"`
} }
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) { func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query q := *query
num := 0 num := 0
@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data) *query = string(data)
} }
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if query != nil { if query != nil {
@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return return
} }
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string { func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
} }
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string { func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
} }
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string { func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes return postgresTypes
} }
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query) row := db.QueryRow(query)
@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0 return cnt > 0
} }
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)
b.ins = b b.ins = b

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// sqlite operators.
var sqliteOperators = map[string]string{ var sqliteOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'", "iexact": "LIKE ? ESCAPE '\\'",
@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'", "iendswith": "LIKE ? ESCAPE '\\'",
} }
// sqlite column types.
var sqliteTypes = map[string]string{ var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
"float64-decimal": "decimal", "float64-decimal": "decimal",
} }
// sqlite dbBaser.
type dbBaseSqlite struct { type dbBaseSqlite struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string { func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol) *leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
} }
} }
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool { func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false return false
} }
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 { func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807 return 9223372036854775807
} }
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string { func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes return sqliteTypes
} }
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string { func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'" return "SELECT name FROM sqlite_master WHERE type = 'table'"
} }
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
return columns, nil return columns, nil
} }
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table) return fmt.Sprintf("pragma table_info('%s')", table)
} }
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table) query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false return false
} }
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)
b.ins = b b.ins = b

View File

@ -6,6 +6,7 @@ import (
"time" "time"
) )
// table info struct.
type dbTable struct { type dbTable struct {
id int id int
index string index string
@ -18,13 +19,17 @@ type dbTable struct {
jtl *dbTable jtl *dbTable
} }
// tables collection struct, contains some tables.
type dbTables struct { type dbTables struct {
tablesM map[string]*dbTable tablesM map[string]*dbTable
tables []*dbTable tables []*dbTable
mi *modelInfo mi *modelInfo
base dbBaser base dbBaser
skipEnd bool
} }
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok { if j, ok := t.tablesM[name]; ok {
@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name] return t.tablesM[name]
} }
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false { if _, ok := t.tablesM[name]; ok == false {
@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name], false return t.tablesM[name], false
} }
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) { func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name] j, ok := t.tablesM[name]
return j, ok return j, ok
} }
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany { if depth < 0 || fi.fieldType == RelManyToMany {
return related return related
@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
return related return related
} }
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) { func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels) relsNum := len(rels)
@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
names = append(names, fi.name) names = append(names, fi.name)
mmi = fi.relModelInfo mmi = fi.relModelInfo
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
} }
// generate join string.
func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote() Q := t.base.TableQuote()
@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
return return
} }
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var ( var (
jtl *dbTable jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi mmi = mi
) )
@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
inner := true inner := true
loopFor:
for i, ex := range exprs { for i, ex := range exprs {
fi, ok := mmi.fields.GetByAny(ex) var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok { if ok {
@ -216,17 +243,33 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
mmi = fi.reverseFieldInfo.mi mmi = fi.reverseFieldInfo.mi
} }
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (fi.mi.isThrough == false || num != i) { if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
jt, _ := t.add(names, mmi, fi, inner) jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl jt.jtl = jtl
jtl = jt jtl = jt
} }
if num == i { }
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil { if i == 0 || jtl == nil {
index = "T0" index = "T0"
} else { } else {
@ -252,7 +295,8 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
name = info.name name = info.name
} }
} }
}
break loopFor
} else { } else {
index = "" index = ""
@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return return
} }
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return return
} }
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) { func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 { if len(orders) == 0 {
return return
@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
return return
} }
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
return return
} }
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable) tables.tablesM = make(map[string]*dbTable)

View File

@ -6,6 +6,7 @@ import (
"time" "time"
) )
// get table alias.
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
return nil return nil
} }
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk fi := mi.fields.pk
@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return return
} }
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor: outFor:

View File

@ -41,6 +41,7 @@ var (
} }
) )
// model info collection
type _modelCache struct { type _modelCache struct {
sync.RWMutex sync.RWMutex
orders []string orders []string
@ -49,6 +50,7 @@ type _modelCache struct {
done bool done bool
} }
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache)) m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache { for k, v := range mc.cache {
@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m return m
} }
// get orderd model info
func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {
@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
return m return m
} }
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table] mi, ok = mc.cache[table]
return return
} }
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name] mi, ok = mc.cacheByFN[name]
return return
} }
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table] mii := mc.cache[table]
mc.cache[table] = mi mc.cache[table] = mi
@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
return mii return mii
} }
// clean all model info.
func (mc *_modelCache) clean() { func (mc *_modelCache) clean() {
mc.orders = make([]string, 0) mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo) mc.cache = make(map[string]*modelInfo)

View File

@ -8,6 +8,8 @@ import (
"strings" "strings"
) )
// register models.
// prefix means table name prefix.
func registerModel(model interface{}, prefix string) { func registerModel(model interface{}, prefix string) {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
modelCache.set(table, info) modelCache.set(table, info)
} }
// boostrap models
func bootStrap() { func bootStrap() {
if modelCache.done { if modelCache.done {
return return
@ -281,6 +284,7 @@ end:
} }
} }
// register models
func RegisterModel(models ...interface{}) { func RegisterModel(models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap")) panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
} }
} }
// bootrap models.
// make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {
return return

View File

@ -9,6 +9,7 @@ import (
var errSkipField = errors.New("skip field") var errSkipField = errors.New("skip field")
// field info collection
type fields struct { type fields struct {
pk *fieldInfo pk *fieldInfo
columns map[string]*fieldInfo columns map[string]*fieldInfo
@ -23,6 +24,7 @@ type fields struct {
dbcols []string dbcols []string
} }
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) { func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil { if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi f.columns[fi.column] = fi
@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
return true return true
} }
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo { func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name] return f.fields[name]
} }
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo { func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column] return f.columns[column]
} }
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) { func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok { if fi, ok := f.fields[name]; ok {
return fi, ok return fi, ok
@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
return nil, false return nil, false
} }
// create new field info collection
func newFields() *fields { func newFields() *fields {
f := new(fields) f := new(fields)
f.fields = make(map[string]*fieldInfo) f.fields = make(map[string]*fieldInfo)
@ -79,6 +85,7 @@ func newFields() *fields {
return f return f
} }
// single field info
type fieldInfo struct { type fieldInfo struct {
mi *modelInfo mi *modelInfo
fieldIndex int fieldIndex int
@ -115,6 +122,7 @@ type fieldInfo struct {
onDelete string onDelete string
} }
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var ( var (
tag string tag string

View File

@ -7,6 +7,7 @@ import (
"reflect" "reflect"
) )
// single model info
type modelInfo struct { type modelInfo struct {
pkg string pkg string
name string name string
@ -20,6 +21,7 @@ type modelInfo struct {
isThrough bool isThrough bool
} }
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) { func newModelInfo(val reflect.Value) (info *modelInfo) {
var ( var (
err error err error
@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
return return
} }
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo) info = new(modelInfo)
info.fields = newFields() info.fields = newFields()

View File

@ -7,10 +7,12 @@ import (
"time" "time"
) )
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string { func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name() return typ.PkgPath() + "." + typ.Name()
} }
// get table name. method, or field name. auto snaked.
func getTableName(val reflect.Value) string { func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
fun := val.MethodByName("TableName") fun := val.MethodByName("TableName")
@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
return snakeString(ind.Type().Name()) return snakeString(ind.Type().Name())
} }
// get table engine, mysiam or innodb.
func getTableEngine(val reflect.Value) string { func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine") fun := val.MethodByName("TableEngine")
if fun.IsValid() { if fun.IsValid() {
@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
return "" return ""
} }
// get table index from method.
func getTableIndex(val reflect.Value) [][]string { func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex") fun := val.MethodByName("TableIndex")
if fun.IsValid() { if fun.IsValid() {
@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
return nil return nil
} }
// get table unique from method
func getTableUnique(val reflect.Value) [][]string { func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique") fun := val.MethodByName("TableUnique")
if fun.IsValid() { if fun.IsValid() {
@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
return nil return nil
} }
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col) col = strings.ToLower(col)
column := col column := col
@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
return column return column
} }
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return return
} }
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool) attr := make(map[string]bool)
tag := make(map[string]string) tag := make(map[string]string)

View File

@ -25,6 +25,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows") ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found") ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
@ -39,11 +40,12 @@ type orm struct {
var _ Ormer = new(orm) var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { // get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind = reflect.Indirect(val) ind = reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
if val.Kind() != reflect.Ptr { if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
name := getFullName(typ) name := getFullName(typ)
@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
} }
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name) fi, ok := mi.fields.GetByAny(name)
if !ok { if !ok {
@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
return fi return fi
} }
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
@ -70,13 +74,21 @@ func (o *orm) Read(md interface{}, cols ...string) error {
return nil return nil
} }
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
if id > 0 {
o.setPk(mi, ind, id)
return id, nil
}
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
@ -85,11 +97,46 @@ func (o *orm) Insert(md interface{}) (int64, error) {
} }
} }
} }
return id, nil
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
} }
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) { func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return num, err return num, err
@ -97,26 +144,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
return num, nil return num, nil
} }
// delete model in database
func (o *orm) Delete(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { o.setPk(mi, ind, 0)
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
} }
return num, nil return num, nil
} }
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
switch { switch {
@ -129,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
return newQueryM2M(md, o, mi, fi, ind) return newQueryM2M(md, o, mi, fi, ind)
} }
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name) _, fi, ind, qseter := o.queryRelated(md, name)
@ -190,14 +241,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
return nums, err return nums, err
} }
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ? // is this api needed ?
_, _, _, qs := o.queryRelated(md, name) _, _, _, qs := o.queryRelated(md, name)
return qs return qs
} }
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind) _, _, exist := getExistPk(mi, ind)
@ -227,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
return mi, fi, ind, qs return mi, fi, ind, qs
} }
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelReverseOne, RelReverseMany: case RelReverseOne, RelReverseMany:
@ -247,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
return q return q
} }
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany: case RelOneToOne, RelForeignKey, RelManyToMany:
@ -266,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
return q return q
} }
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
@ -285,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return return
} }
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error { func (o *orm) Using(name string) error {
if o.isTx { if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db")) panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
@ -302,6 +366,7 @@ func (o *orm) Using(name string) error {
return nil return nil
} }
// begin transaction
func (o *orm) Begin() error { func (o *orm) Begin() error {
if o.isTx { if o.isTx {
return ErrTxHasBegan return ErrTxHasBegan
@ -320,6 +385,7 @@ func (o *orm) Begin() error {
return nil return nil
} }
// commit transaction
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -334,6 +400,7 @@ func (o *orm) Commit() error {
return err return err
} }
// rollback transaction
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -348,14 +415,17 @@ func (o *orm) Rollback() error {
return err return err
} }
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter { func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args) return newRawSet(o, query, args)
} }
// return current using database Driver
func (o *orm) Driver() Driver { func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
// create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once

View File

@ -18,15 +18,19 @@ type condValue struct {
isCond bool isCond bool
} }
// condition struct.
// work for WHERE conditions.
type Condition struct { type Condition struct {
params []condValue params []condValue
} }
// return new condition struct
func NewCondition() *Condition { func NewCondition() *Condition {
c := &Condition{} c := &Condition{}
return c return c
} }
// add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty")) panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
// add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty")) panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c return c
} }
// check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool { func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
// clone a condition
func (c Condition) clone() *Condition { func (c Condition) clone() *Condition {
return &c return &c
} }

View File

@ -13,6 +13,7 @@ type Log struct {
*log.Logger *log.Logger
} }
// set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", 1e9)
@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
DebugLog.Println(con) DebugLog.Println(con)
} }
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct { type stmtQueryLog struct {
alias *alias alias *alias
query string query string
@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
return d return d
} }
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct { type dbQueryLog struct {
alias *alias alias *alias
db dbQuerier db dbQuerier

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
) )
// an insert queryer struct
type insertSet struct { type insertSet struct {
mi *modelInfo mi *modelInfo
orm *orm orm *orm
@ -14,6 +15,7 @@ type insertSet struct {
var _ Inserter = new(insertSet) var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) { func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
return id, nil return id, nil
} }
// close insert queryer statement
func (o *insertSet) Close() error { func (o *insertSet) Close() error {
if o.closed { if o.closed {
return ErrStmtClosed return ErrStmtClosed
@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
return o.stmt.Close() return o.stmt.Close()
} }
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet) bi := new(insertSet)
bi.orm = orm bi.orm = orm

View File

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
) )
// model to model struct
type queryM2M struct { type queryM2M struct {
md interface{} md interface{}
mi *modelInfo mi *modelInfo
@ -12,6 +13,13 @@ type queryM2M struct {
ind reflect.Value ind reflect.Value
} }
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) { func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
mi := fi.relThroughModelInfo mi := fi.relThroughModelInfo
@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column} names := []string{mfi.column, rfi.column}
var nums int64 values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
} }
values := []interface{}{v1, v2} values = append(values, v1, v2)
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
} }
nums += 1 return dbase.InsertValue(orm.db, mi, true, names, values)
}
return nums, nil
} }
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return nums, nil return nums, nil
} }
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool { func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md). return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist() Filter(fi.reverseFieldInfoTwo.name, md).Exist()
} }
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) { func (o *queryM2M) Clear() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
} }
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) { func (o *queryM2M) Count() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
var _ QueryM2Mer = new(queryM2M) var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M) qm2m := new(queryM2M)
qm2m.md = md qm2m.md = md

View File

@ -18,6 +18,10 @@ const (
Col_Except Col_Except
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except: case Col_Add, Col_Minus, Col_Multiply, Col_Except:
@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
return val return val
} }
// real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
@ -47,6 +52,7 @@ type querySet struct {
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// set offset number
func (o *querySet) setOffset(num interface{}) { func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num) o.offset = ToInt64(num)
} }
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit) o.limit = ToInt64(limit)
if len(args) > 0 { if len(args) > 0 {
@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
return &o return &o
} }
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter { func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset) o.setOffset(offset)
return &o return &o
} }
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs o.orders = exprs
return &o return &o
} }
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string var related []string
if len(params) == 0 { if len(params) == 0 {
@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
return &o return &o
} }
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter { func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond o.cond = cond
return &o return &o
} }
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool { func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0 return cnt > 0
} }
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) { func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
} }
// execute delete
func (o *querySet) Delete() (int64, error) { func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) { func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi) return newInsertSet(o.orm, o.mi)
} }
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) { func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
} }
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil { if err != nil {
@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error {
return nil return nil
} }
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
} }
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
o.mi = mi o.mi = mi

View File

@ -4,10 +4,10 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
// raw sql string prepared statement
type rawPrepare struct { type rawPrepare struct {
rs *rawSet rs *rawSet
stmt stmtQuerier stmt stmtQuerier
@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
return o, nil return o, nil
} }
// raw query seter
type rawSet struct { type rawSet struct {
query string query string
args []interface{} args []interface{}
@ -53,11 +54,13 @@ type rawSet struct {
var _ RawSeter = new(rawSet) var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter { func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args o.args = args
return &o return &o
} }
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) { func (o *rawSet) Exec() (sql.Result, error) {
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(query, args...) return o.orm.db.Exec(query, args...)
} }
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() { switch ind.Kind() {
case reflect.Bool: case reflect.Bool:
@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} }
} }
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) { // set field value in loop for slice container
sIdxes := *sIdxesPtr func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr nInds := *nIndsPtr
cur := 0 cur := 0
for i, idxs := range sIdxes { for i := 0; i < len(sInds); i++ {
sInd := sInds[i] sInd := sInds[i]
eTyp := eTyps[i] eTyp := eTyps[i]
@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value) o.setFieldValue(ind, value)
} }
cur++ cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
} }
} else { } else {
value := reflect.ValueOf(refs[cur]).Elem().Interface() value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil { if isPtr && value == nil {
@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
} }
} }
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error { func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers)) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
typ = typ.Elem() typ = typ.Elem()
} }
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind) sInds = append(sInds, ind)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...) rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err := row.Scan(refs...); err == sql.ErrNoRows { if err == sql.ErrNoRows {
return ErrNoRows return ErrNoRows
} else if err != nil { }
return err
}
defer rows.Close()
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err return err
} }
nInds := make([]reflect.Value, len(sInds)) nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true) o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds { for i, sInd := range sInds {
nInd := nInds[i] nInd := nInds[i]
sInd.Set(nInd) sInd.Set(nInd)
} }
}
} else {
return ErrNoRows
}
return nil return nil
} }
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
sInd := reflect.Indirect(val) sInd := reflect.Indirect(val)
@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd) sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
@ -401,25 +418,102 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return 0, err return 0, err
} }
nInds := make([]reflect.Value, len(sInds)) defer rows.Close()
var cnt int64 var cnt int64
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
for rows.Next() { for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil { if err := rows.Scan(refs...); err != nil {
return 0, err return 0, err
} }
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0) if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++ cnt++
} }
if cnt > 0 { if cnt > 0 {
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds { for i, sInd := range sInds {
nInd := nInds[i] nInd := nInds[i]
sInd.Set(nInd) sInd.Set(nInd)
} }
} }
}
return cnt, nil return cnt, nil
} }
@ -455,6 +549,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
rs = r rs = r
} }
defer rs.Close()
var ( var (
refs []interface{} refs []interface{}
cnt int64 cnt int64
@ -527,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil return cnt, nil
} }
// query data to []map[string]interface
func (o *rawSet) Values(container *[]Params) (int64, error) { func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container) return o.readValues(container)
} }
// query data to [][]interface
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container) return o.readValues(container)
} }
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container) return o.readValues(container)
} }
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) { func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o) return newRawPreparer(o)
} }

View File

@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
} }
} }
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var ( var (
uid int uid int
status *int status *int
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) { func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q) query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) num, err := dORM.Raw(query).QueryRows(&datas)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1)) throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0])) ind := reflect.Indirect(reflect.ValueOf(datas[0]))
@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
type Tmp struct { var datas2 []Data
Id int
Name string query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
Skiped0 string `orm:"-"` num, err = dORM.Raw(query).QueryRows(&datas2)
Pid *int throwFailNow(t, err)
Skiped1 Data throwFailNow(t, AssertIs(num, 1))
Skiped2 *Data throwFailNow(t, AssertIs(len(datas2), 1))
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
} }
var ( var ids []int
ids []int var usernames []string
userNames []string query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
profileIds1 []int num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(len(ids), 3))
var users []User throwFailNow(t, AssertIs(ids[0], 2))
dORM.QueryTable("user").OrderBy("Id").All(&users) throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
for i := 0; i < 3; i++ { throwFailNow(t, AssertIs(usernames[1], "astaxie"))
id := ids[i] throwFailNow(t, AssertIs(ids[2], 4))
name := userNames[i] throwFailNow(t, AssertIs(usernames[2], "nobody"))
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
}
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 6))
qs = dORM.QueryTable("post")
num, err = qs.Filter("Id", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
fmt.Println("...")
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {

View File

@ -6,11 +6,13 @@ import (
"time" "time"
) )
// database driver
type Driver interface { type Driver interface {
Name() string Name() string
Type() DriverType Type() DriverType
} }
// field info
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -18,9 +20,11 @@ type Fielder interface {
RawValue() interface{} RawValue() interface{}
} }
// orm struct
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error Read(interface{}, ...string) error
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error) Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error)
@ -34,11 +38,13 @@ type Ormer interface {
Driver() Driver Driver() Driver
} }
// insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
// query seter
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
@ -59,6 +65,7 @@ type QuerySeter interface {
ValuesFlat(*ParamsList, string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error)
} }
// model to model query struct
type QueryM2Mer interface { type QueryM2Mer interface {
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
@ -67,11 +74,13 @@ type QueryM2Mer interface {
Count() (int64, error) Count() (int64, error)
} }
// raw query statement
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (sql.Result, error) Exec(...interface{}) (sql.Result, error)
Close() error Close() error
} }
// raw query seter
type RawSeter interface { type RawSeter interface {
Exec() (sql.Result, error) Exec() (sql.Result, error)
QueryRow(...interface{}) error QueryRow(...interface{}) error
@ -83,6 +92,7 @@ type RawSeter interface {
Prepare() (RawPreparer, error) Prepare() (RawPreparer, error)
} }
// statement querier
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
@ -90,6 +100,7 @@ type stmtQuerier interface {
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
} }
// db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
@ -97,19 +108,23 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
} }
// transaction beginner
type txer interface { type txer interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
} }
// transaction ending
type txEnder interface { type txEnder interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)

View File

@ -10,6 +10,7 @@ import (
type StrTo string type StrTo string
// set string
func (f *StrTo) Set(v string) { func (f *StrTo) Set(v string) {
if v != "" { if v != "" {
*f = StrTo(v) *f = StrTo(v)
@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
} }
} }
// clean string
func (f *StrTo) Clear() { func (f *StrTo) Clear() {
*f = StrTo(0x1E) *f = StrTo(0x1E)
} }
// check string exist
func (f StrTo) Exist() bool { func (f StrTo) Exist() bool {
return string(f) != string(0x1E) return string(f) != string(0x1E)
} }
// string to bool
func (f StrTo) Bool() (bool, error) { func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String()) return strconv.ParseBool(f.String())
} }
// string to float32
func (f StrTo) Float32() (float32, error) { func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32) v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err return float32(v), err
} }
// string to float64
func (f StrTo) Float64() (float64, error) { func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
// string to int
func (f StrTo) Int() (int, error) { func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err return int(v), err
} }
// string to int8
func (f StrTo) Int8() (int8, error) { func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8) v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err return int8(v), err
} }
// string to int16
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
} }
// string to int32
func (f StrTo) Int32() (int32, error) { func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err return int32(v), err
} }
// string to int64
func (f StrTo) Int64() (int64, error) { func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64) v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err return int64(v), err
} }
// string to uint
func (f StrTo) Uint() (uint, error) { func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err return uint(v), err
} }
// string to uint8
func (f StrTo) Uint8() (uint8, error) { func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8) v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err return uint8(v), err
} }
// string to uint16
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
} }
// string to uint31
func (f StrTo) Uint32() (uint32, error) { func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err return uint32(v), err
} }
// string to uint64
func (f StrTo) Uint64() (uint64, error) { func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64) v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err return uint64(v), err
} }
// string to string
func (f StrTo) String() string { func (f StrTo) String() string {
if f.Exist() { if f.Exist() {
return string(f) return string(f)
@ -96,6 +113,7 @@ func (f StrTo) String() string {
return "" return ""
} }
// interface to string
func ToStr(value interface{}, args ...int) (s string) { func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) { switch v := value.(type) {
case bool: case bool:
@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s return s
} }
// interface to int64
func ToInt64(value interface{}) (d int64) { func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value) val := reflect.ValueOf(value)
switch value.(type) { switch value.(type) {
@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
return return
} }
// snake string, XxYy to xx_yy
func snakeString(s string) string { func snakeString(s string) string {
data := make([]byte, 0, len(s)*2) data := make([]byte, 0, len(s)*2)
j := false j := false
@ -164,6 +184,7 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:len(data)])) return strings.ToLower(string(data[:len(data)]))
} }
// camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))
j := false j := false
@ -190,6 +211,7 @@ func camelString(s string) string {
type argString []string type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) { func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
type argInt []int type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) { func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
type argAny []interface{} type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) { func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
return return
} }
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err
} }
// format time string
func timeFormat(t time.Time, format string) string { func timeFormat(t time.Time, format string) string {
return t.Format(format) return t.Format(format)
} }
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type { func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() { switch v.Kind() {
case reflect.Ptr: case reflect.Ptr:

View File

@ -1,7 +1,10 @@
package beego package beego
import ( import (
"bufio"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -30,6 +33,14 @@ const (
var ( var (
// supported http methods. // supported http methods.
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"} HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
) )
type controllerInfo struct { type controllerInfo struct {
@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
params := make(map[int]string) params := make(map[int]string)
for i, part := range parts { for i, part := range parts {
if strings.HasPrefix(part, ":") { if strings.HasPrefix(part, ":") {
expr := "(.+)" expr := "(.*)"
//a user may choose to override the defult expression //a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+) // similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 { if index := strings.Index(part, "("); index != -1 {
@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
j++ j++
} }
if strings.HasPrefix(part, "*") { if strings.HasPrefix(part, "*") {
expr := "(.+)" expr := "(.*)"
if part == "*.*" { if part == "*.*" {
params[j] = ":path" params[j] = ":path"
parts[i] = "([^.]+).([^.]+)" parts[i] = "([^.]+).([^.]+)"
@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
// Add auto router to ControllerRegistor. // Add auto router to ControllerRegistor.
// example beego.AddAuto(&MainContorlller{}), // example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page. // MainController has method List and Page.
// visit the url /main/list to exec List function // visit the url /main/list to execute List function
// /main/page to exec Page function. // /main/page to execute Page function.
func (p *ControllerRegistor) AddAuto(c ControllerInterface) { func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.enableAuto = true p.enableAuto = true
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.autoRouter[firstParam] = make(map[string]reflect.Type) p.autoRouter[firstParam] = make(map[string]reflect.Type)
} }
for i := 0; i < rt.NumMethod(); i++ { for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct p.autoRouter[firstParam][rt.Method(i).Name] = ct
} }
} }
}
// Add auto router to ControllerRegistor with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
if _, ok := p.autoRouter[firstParam]; ok {
return
} else {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// [Deprecated] use InsertFilter. // [Deprecated] use InsertFilter.
// Add FilterFunc with pattern for action. // Add FilterFunc with pattern for action.
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) { func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
mr := buildFilter(pattern, filter) mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
switch action { switch action {
case "BeforeRouter": case "BeforeRouter":
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr) p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr) p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
} }
p.enableFilter = true p.enableFilter = true
return nil
} }
// Add a FilterFunc with pattern rule and action constant. // Add a FilterFunc with pattern rule and action constant.
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) { func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
mr := buildFilter(pattern, filter) mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
p.filters[pos] = append(p.filters[pos], mr) p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true p.enableFilter = true
return nil
} }
// UrlFor does another controller handler in this request function. // UrlFor does another controller handler in this request function.
@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init // session init
if SessionOn { if SessionOn {
context.Input.CruSession = GlobalSessions.SessionStart(w, r) context.Input.CruSession = GlobalSessions.SessionStart(w, r)
defer context.Input.CruSession.SessionRelease() defer func() {
context.Input.CruSession.SessionRelease(w)
}()
} }
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// pattern /admin url /admin 200 /admin/ 200 // pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200 // pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && len(route.pattern) == n+1 && if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
http.Redirect(w, r, requestPath+"/", 301) http.Redirect(w, r, requestPath+"/", 301)
goto Admin goto Admin
} }
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern { if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route) runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" { if runMethod != "" {
runrouter = route.controllerType runrouter = route.controllerType
@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
w.started = true w.started = true
w.writer.WriteHeader(code) w.writer.WriteHeader(code)
} }
// hijacker for http
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.writer.(http.Hijacker)
if !ok {
println("supported?")
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}

View File

@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
t.Errorf(w.Body.String() + "user define func can't run") t.Errorf(w.Body.String() + "user define func can't run")
} }
} }
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}

View File

@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider: * Use **memory** as provider:
func init() { func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **file** as provider, the last param is the path where you want file to be stored: * Use **file** as provider, the last param is the path where you want file to be stored:
func init() { func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() { func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC() go globalSessions.GC()
} }
@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() { func init() {
globalSessions, _ = session.NewManager( globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) { func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r) sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease() defer sess.SessionRelease(w)
username := sess.Get("username") username := sess.Get("username")
fmt.Println(username) fmt.Println(username)
if r.Method == "GET" { if r.Method == "GET" {
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition. (Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example. Maybe you will find the **memory** provider is a good example.
type SessionStore interface { type SessionStore interface {
Set(key, value interface{}) error //set session value Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)

145
session/sess_cookie.go Normal file
View File

@ -0,0 +1,145 @@
package session
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} //session data
lock sync.RWMutex
}
func (st *CookieSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
func (st *CookieSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
func (st *CookieSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
cookiepder.config.SecurityName,
st.values)
if err != nil {
return
}
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(str),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure}
http.SetCookie(w, cookie)
return
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
return nil
}
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil, nil
}
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
func (pder *CookieProvider) SessionGC() {
return
}
func (pder *CookieProvider) SessionAll() int {
return 0
}
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,38 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
globalSessions, err := NewManager("cookie", config)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
return fs.sid return fs.sid
} }
func (fs *FileSessionStore) SessionRelease() { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
defer fs.f.Close() defer fs.f.Close()
b, err := encodeGob(fs.values) b, err := encodeGob(fs.values)
if err != nil { if err != nil {

View File

@ -1,38 +0,0 @@
package session
import (
"bytes"
"encoding/gob"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}

View File

@ -2,6 +2,7 @@ package session
import ( import (
"container/list" "container/list"
"net/http"
"sync" "sync"
"time" "time"
) )
@ -9,9 +10,9 @@ import (
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
type MemSessionStore struct { type MemSessionStore struct {
sid string //session id唯一标示 sid string //session id
timeAccessed time.Time //最后访问时间 timeAccessed time.Time //last access time
value map[interface{}]interface{} //session里面存储的值 value map[interface{}]interface{} //session store
lock sync.RWMutex lock sync.RWMutex
} }
@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MemSessionStore) SessionRelease() { func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
} }
type MemProvider struct { type MemProvider struct {

35
session/sess_mem_test.go Normal file
View File

@ -0,0 +1,35 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
err := sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -9,6 +9,7 @@ package session
import ( import (
"database/sql" "database/sql"
"net/http"
"sync" "sync"
"time" "time"
@ -60,15 +61,15 @@ func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MysqlSessionStore) SessionRelease() { func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
if len(st.values) > 0 {
b, err := encodeGob(st.values) b, err := encodeGob(st.values)
if err != nil { if err != nil {
return return
} }
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid) st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
} b, time.Now().Unix(), st.sid)
} }
type MysqlProvider struct { type MysqlProvider struct {
@ -96,7 +97,8 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
} }
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(sessiondata) == 0 { if len(sessiondata) == 0 {
@ -113,6 +115,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
func (mp *MysqlProvider) SessionExist(sid string) bool { func (mp *MysqlProvider) SessionExist(sid string) bool {
c := mp.connectInit() c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid) row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -58,9 +59,8 @@ func (rs *RedisSessionStore) SessionID() string {
return rs.sid return rs.sid
} }
func (rs *RedisSessionStore) SessionRelease() { func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
defer rs.c.Close() defer rs.c.Close()
if len(rs.values) > 0 {
b, err := encodeGob(rs.values) b, err := encodeGob(rs.values)
if err != nil { if err != nil {
return return
@ -68,7 +68,6 @@ func (rs *RedisSessionStore) SessionRelease() {
rs.c.Do("SET", rs.sid, string(b)) rs.c.Do("SET", rs.sid, string(b))
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime) rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
} }
}
type RedisProvider struct { type RedisProvider struct {
maxlifetime int64 maxlifetime int64

View File

@ -1,6 +1,8 @@
package session package session
import ( import (
"crypto/aes"
"encoding/json"
"testing" "testing"
) )
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
t.Error("decode int error") t.Error("decode int error")
} }
} }
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst := make(map[interface{}]interface{})
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(managerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie != false {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

188
session/sess_utils.go Normal file
View File

@ -0,0 +1,188 @@
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. encodeGob.
if b, err = encodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. decodeGob.
if dst, err := decodeGob(b); err != nil {
return nil, err
} else {
return dst, nil
}
// Done.
return nil, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -18,12 +19,12 @@ type SessionStore interface {
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)
@ -47,15 +48,22 @@ func Register(name string, provide Provider) {
provides[name] = provide provides[name] = provide
} }
type managerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
Maxage int `json:"maxage"`
Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int64 `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
}
type Manager struct { type Manager struct {
cookieName string //private cookiename
provider Provider provider Provider
maxlifetime int64 config *managerConfig
hashfunc string //support md5 & sha1
hashkey string
maxage int //cookielifetime
secure bool
options []interface{}
} }
//options //options
@ -63,74 +71,54 @@ type Manager struct {
//2. hashfunc default sha1 //2. hashfunc default sha1
//3. hashkey default beegosessionkey //3. hashkey default beegosessionkey
//4. maxage default is none //4. maxage default is none
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { func NewManager(provideName, config string) (*Manager, error) {
provider, ok := provides[provideName] provider, ok := provides[provideName]
if !ok { if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
} }
provider.SessionInit(maxlifetime, savePath) cf := new(managerConfig)
secure := false cf.EnableSetCookie = true
if len(options) > 0 { err := json.Unmarshal([]byte(config), cf)
secure = options[0].(bool) if err != nil {
return nil, err
} }
hashfunc := "sha1" if cf.Maxlifetime == 0 {
if len(options) > 1 { cf.Maxlifetime = cf.Gclifetime
hashfunc = options[1].(string)
} }
hashkey := "beegosessionkey" err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if len(options) > 2 { if err != nil {
hashkey = options[2].(string) return nil, err
}
maxage := -1
if len(options) > 3 {
switch options[3].(type) {
case int:
if options[3].(int) > 0 {
maxage = options[3].(int)
} else if options[3].(int) < 0 {
maxage = 0
}
case int64:
if options[3].(int64) > 0 {
maxage = int(options[3].(int64))
} else if options[3].(int64) < 0 {
maxage = 0
}
case int32:
if options[3].(int32) > 0 {
maxage = int(options[3].(int32))
} else if options[3].(int32) < 0 {
maxage = 0
} }
if cf.SessionIDHashFunc == "" {
cf.SessionIDHashFunc = "sha1"
} }
if cf.SessionIDHashKey == "" {
cf.SessionIDHashKey = string(generateRandomKey(16))
} }
return &Manager{ return &Manager{
provider: provider, provider,
cookieName: cookieName, cf,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
maxage: maxage,
secure: secure,
options: options,
}, nil }, nil
} }
//get Session //get Session
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
sid := manager.sessionId(r) sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
}
r.AddCookie(cookie) r.AddCookie(cookie)
} else { } else {
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
@ -139,15 +127,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} else { } else {
sid = manager.sessionId(r) sid = manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
}
r.AddCookie(cookie) r.AddCookie(cookie)
} }
} }
@ -156,13 +146,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
//Destroy sessionid //Destroy sessionid
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
return return
} else { } else {
manager.provider.SessionDestroy(cookie.Value) manager.provider.SessionDestroy(cookie.Value)
expiration := time.Now() expiration := time.Now()
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} cookie := http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
} }
@ -174,20 +168,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro
func (manager *Manager) GC() { func (manager *Manager) GC() {
manager.provider.SessionGC() manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
} }
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r) sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil && cookie.Value == "" { if err != nil && cookie.Value == "" {
//delete old cookie //delete old cookie
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure, Secure: manager.config.Secure,
} }
} else { } else {
oldsid, _ := url.QueryUnescape(cookie.Value) oldsid, _ := url.QueryUnescape(cookie.Value)
@ -196,8 +190,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
} }
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
@ -209,12 +203,12 @@ func (manager *Manager) GetActiveSession() int {
} }
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
manager.hashfunc = hasfunc manager.config.SessionIDHashFunc = hasfunc
manager.hashkey = hashkey manager.config.SessionIDHashKey = hashkey
} }
func (manager *Manager) SetSecure(secure bool) { func (manager *Manager) SetSecure(secure bool) {
manager.secure = secure manager.config.Secure = secure
} }
//remote_addr cruunixnano randdata //remote_addr cruunixnano randdata
@ -224,16 +218,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) {
return "" return ""
} }
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" { if manager.config.SessionIDHashFunc == "md5" {
h := md5.New() h := md5.New()
h.Write([]byte(sig)) h.Write([]byte(sig))
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else if manager.hashfunc == "sha1" { } else if manager.config.SessionIDHashFunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else { } else {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} }

45
utils/captcha/README.md Normal file
View File

@ -0,0 +1,45 @@
# Captcha
an example for use captcha
```
package controllers
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/utils/captcha"
)
var cpt *captcha.Captcha
func init() {
// use beego cache system store the captcha data
store := cache.NewMemoryCache()
cpt = captcha.NewWithFilter("/captcha/", store)
}
type MainController struct {
beego.Controller
}
func (this *MainController) Get() {
this.TplNames = "index.tpl"
}
func (this *MainController) Post() {
this.TplNames = "index.tpl"
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
}
```
template usage
```
{{.Success}}
<form action="/" method="post">
{{create_captcha}}
<input name="captcha" type="text">
</form>
```

248
utils/captcha/captcha.go Normal file
View File

@ -0,0 +1,248 @@
// an example for use captcha
//
// ```
// package controllers
//
// import (
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/cache"
// "github.com/astaxie/beego/utils/captcha"
// )
//
// var cpt *captcha.Captcha
//
// func init() {
// // use beego cache system store the captcha data
// store := cache.NewMemoryCache()
// cpt = captcha.NewWithFilter("/captcha/", store)
// }
//
// type MainController struct {
// beego.Controller
// }
//
// func (this *MainController) Get() {
// this.TplNames = "index.tpl"
// }
//
// func (this *MainController) Post() {
// this.TplNames = "index.tpl"
//
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
// }
// ```
//
// template usage
//
// ```
// {{.Success}}
// <form action="/" method="post">
// {{create_captcha}}
// <input name="captcha" type="text">
// </form>
// ```
package captcha
import (
"fmt"
"html/template"
"net/http"
"path"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/utils"
)
var (
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
)
const (
// default captcha attributes
challengeNums = 6
expiration = 600
fieldIdName = "captcha_id"
fieldCaptchaName = "captcha"
cachePrefix = "captcha_"
urlPrefix = "/captcha/"
)
type Captcha struct {
// beego cache store
store cache.Cache
// url prefix for captcha image
urlPrefix string
// specify captcha id input field name
FieldIdName string
// specify captcha result input field name
FieldCaptchaName string
// captcha image width and height
StdWidth int
StdHeight int
// captcha chars nums
ChallengeNums int
// captcha expiration seconds
Expiration int64
// cache key prefix
CachePrefix string
}
func (c *Captcha) key(id string) string {
return c.CachePrefix + id
}
func (c *Captcha) genRandChars() []byte {
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
}
// beego filter handler for serve captcha image
func (c *Captcha) Handler(ctx *context.Context) {
var chars []byte
id := path.Base(ctx.Request.RequestURI)
if i := strings.Index(id, "."); i != -1 {
id = id[:i]
}
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok {
chars = v
} else {
ctx.Output.SetStatus(404)
ctx.WriteString("captcha not found")
return
}
// reload captcha
if len(ctx.Input.Query("reload")) > 0 {
chars = c.genRandChars()
if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error")
beego.Error("Reload Create Captcha Error:", err)
return
}
}
img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
beego.Error("Write Captcha Image Error:", err)
}
}
// tempalte func for output html
func (c *Captcha) CreateCaptchaHtml() template.HTML {
value, err := c.CreateCaptcha()
if err != nil {
beego.Error("Create Captcha Error:", err)
return ""
}
// create html
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
`<a class="captcha" href="javascript:">`+
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
}
// create a new captcha id
func (c *Captcha) CreateCaptcha() (string, error) {
// generate captcha id
id := string(utils.RandomCreateBytes(15))
// get the captcha chars
chars := c.genRandChars()
// save to store
if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil {
return "", err
}
return id, nil
}
// verify from a request
func (c *Captcha) VerifyReq(req *http.Request) bool {
req.ParseForm()
return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName))
}
// direct verify id and challenge string
func (c *Captcha) Verify(id string, challenge string) (success bool) {
if len(challenge) == 0 || len(id) == 0 {
return
}
var chars []byte
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) {
chars = v
} else {
return
}
defer func() {
// finally remove it
c.store.Delete(key)
}()
// verify challenge
for i, c := range chars {
if c != challenge[i]-48 {
return
}
}
return true
}
// create a new captcha.Captcha
func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
cpt := &Captcha{}
cpt.store = store
cpt.FieldIdName = fieldIdName
cpt.FieldCaptchaName = fieldCaptchaName
cpt.ChallengeNums = challengeNums
cpt.Expiration = expiration
cpt.CachePrefix = cachePrefix
cpt.StdWidth = stdWidth
cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 {
urlPrefix = urlPrefix
}
if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/"
}
cpt.urlPrefix = urlPrefix
return cpt
}
// create a new captcha.Captcha and auto AddFilter for serve captacha image
// and add a tempalte func for output html
func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
// add to template func map
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
return cpt
}

484
utils/captcha/image.go Normal file
View File

@ -0,0 +1,484 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"bytes"
"image"
"image/color"
"image/png"
"io"
"math"
)
const (
fontWidth = 11
fontHeight = 18
blackChar = 1
// Standard width and height of a captcha image.
stdWidth = 240
stdHeight = 80
// Maximum absolute skew factor of a single digit.
maxSkew = 0.7
// Number of background circles.
circleCount = 20
)
var font = [][]byte{
{ // 0
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 1
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 2
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 3
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 4
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
},
{ // 5
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 6
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 7
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
},
{ // 8
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 9
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
},
}
type Image struct {
*image.Paletted
numWidth int
numHeight int
dotSize int
}
var prng = &siprng{}
// randIntn returns a pseudorandom non-negative int in range [0, n).
func randIntn(n int) int {
return prng.Intn(n)
}
// randInt returns a pseudorandom int in range [from, to].
func randInt(from, to int) int {
return prng.Intn(to+1-from) + from
}
// randFloat returns a pseudorandom float64 in range [from, to].
func randFloat(from, to float64) float64 {
return (to-from)*prng.Float64() + from
}
func randomPalette() color.Palette {
p := make([]color.Color, circleCount+1)
// Transparent color.
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
// Primary color.
prim := color.RGBA{
uint8(randIntn(129)),
uint8(randIntn(129)),
uint8(randIntn(129)),
0xFF,
}
p[1] = prim
// Circle colors.
for i := 2; i <= circleCount; i++ {
p[i] = randomBrightness(prim, 255)
}
return p
}
// NewImage returns a new captcha image of the given width and height with the
// given digits, where each digit must be in range 0-9.
func NewImage(digits []byte, width, height int) *Image {
m := new(Image)
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
m.calculateSizes(width, height, len(digits))
// Randomly position captcha inside the image.
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
maxy := height - m.numHeight - m.dotSize*2
var border int
if width > height {
border = height / 5
} else {
border = width / 5
}
x := randInt(border, maxx-border)
y := randInt(border, maxy-border)
// Draw digits.
for _, n := range digits {
m.drawDigit(font[n], x, y)
x += m.numWidth + m.dotSize
}
// Draw strike-through line.
m.strikeThrough()
// Apply wave distortion.
m.distort(randFloat(5, 10), randFloat(100, 200))
// Fill image with random circles.
m.fillWithCircles(circleCount, m.dotSize)
return m
}
// encodedPNG encodes an image to PNG and returns
// the result as a byte slice.
func (m *Image) encodedPNG() []byte {
var buf bytes.Buffer
if err := png.Encode(&buf, m.Paletted); err != nil {
panic(err.Error())
}
return buf.Bytes()
}
// WriteTo writes captcha image in PNG format into the given writer.
func (m *Image) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.encodedPNG())
return int64(n), err
}
func (m *Image) calculateSizes(width, height, ncount int) {
// Goal: fit all digits inside the image.
var border int
if width > height {
border = height / 4
} else {
border = width / 4
}
// Convert everything to floats for calculations.
w := float64(width - border*2)
h := float64(height - border*2)
// fw takes into account 1-dot spacing between digits.
fw := float64(fontWidth + 1)
fh := float64(fontHeight)
nc := float64(ncount)
// Calculate the width of a single digit taking into account only the
// width of the image.
nw := w / nc
// Calculate the height of a digit from this width.
nh := nw * fh / fw
// Digit too high?
if nh > h {
// Fit digits based on height.
nh = h
nw = fw / fh * nh
}
// Calculate dot size.
m.dotSize = int(nh / fh)
// Save everything, making the actual width smaller by 1 dot to account
// for spacing between digits.
m.numWidth = int(nw) - m.dotSize
m.numHeight = int(nh)
}
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
for x := fromX; x <= toX; x++ {
m.SetColorIndex(x, y, colorIdx)
}
}
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
f := 1 - radius
dfx := 1
dfy := -2 * radius
xo := 0
yo := radius
m.SetColorIndex(x, y+radius, colorIdx)
m.SetColorIndex(x, y-radius, colorIdx)
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
for xo < yo {
if f >= 0 {
yo--
dfy += 2
f += dfy
}
xo++
dfx += 2
f += dfx
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
}
}
func (m *Image) fillWithCircles(n, maxradius int) {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
for i := 0; i < n; i++ {
colorIdx := uint8(randInt(1, circleCount-1))
r := randInt(1, maxradius)
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
}
}
func (m *Image) strikeThrough() {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
y := randInt(maxy/3, maxy-maxy/3)
amplitude := randFloat(5, 20)
period := randFloat(80, 180)
dx := 2.0 * math.Pi / period
for x := 0; x < maxx; x++ {
xo := amplitude * math.Cos(float64(y)*dx)
yo := amplitude * math.Sin(float64(x)*dx)
for yn := 0; yn < m.dotSize; yn++ {
r := randInt(0, m.dotSize)
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
}
}
}
func (m *Image) drawDigit(digit []byte, x, y int) {
skf := randFloat(-maxSkew, maxSkew)
xs := float64(x)
r := m.dotSize / 2
y += randInt(-r, r)
for yo := 0; yo < fontHeight; yo++ {
for xo := 0; xo < fontWidth; xo++ {
if digit[yo*fontWidth+xo] != blackChar {
continue
}
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
}
xs += skf
x = int(xs)
}
}
func (m *Image) distort(amplude float64, period float64) {
w := m.Bounds().Max.X
h := m.Bounds().Max.Y
oldm := m.Paletted
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
dx := 2.0 * math.Pi / period
for x := 0; x < w; x++ {
for y := 0; y < h; y++ {
xo := amplude * math.Sin(float64(y)*dx)
yo := amplude * math.Cos(float64(x)*dx)
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
}
}
m.Paletted = newm
}
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
minc := min3(c.R, c.G, c.B)
maxc := max3(c.R, c.G, c.B)
if maxc > max {
return c
}
n := randIntn(int(max-maxc)) - int(minc)
return color.RGBA{
uint8(int(c.R) + n),
uint8(int(c.G) + n),
uint8(int(c.B) + n),
uint8(c.A),
}
}
func min3(x, y, z uint8) (m uint8) {
m = x
if y < m {
m = y
}
if z < m {
m = z
}
return
}
func max3(x, y, z uint8) (m uint8) {
m = x
if y > m {
m = y
}
if z > m {
m = z
}
return
}

View File

@ -0,0 +1,38 @@
package captcha
import (
"testing"
"github.com/astaxie/beego/utils"
)
type byteCounter struct {
n int64
}
func (bc *byteCounter) Write(b []byte) (int, error) {
bc.n += int64(len(b))
return len(b), nil
}
func BenchmarkNewImage(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
for i := 0; i < b.N; i++ {
NewImage(d, stdWidth, stdHeight)
}
}
func BenchmarkImageWriteTo(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
counter := &byteCounter{}
for i := 0; i < b.N; i++ {
img := NewImage(d, stdWidth, stdHeight)
img.WriteTo(counter)
b.SetBytes(counter.n)
counter.n = 0
}
}

264
utils/captcha/siprng.go Normal file
View File

@ -0,0 +1,264 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"crypto/rand"
"encoding/binary"
"io"
"sync"
)
// siprng is PRNG based on SipHash-2-4.
type siprng struct {
mu sync.Mutex
k0, k1, ctr uint64
}
// siphash implements SipHash-2-4, accepting a uint64 as a message.
func siphash(k0, k1, m uint64) uint64 {
// Initialization.
v0 := k0 ^ 0x736f6d6570736575
v1 := k1 ^ 0x646f72616e646f6d
v2 := k0 ^ 0x6c7967656e657261
v3 := k1 ^ 0x7465646279746573
t := uint64(8) << 56
// Compression.
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= m
// Compress last block.
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
return v0 ^ v1 ^ v2 ^ v3
}
// rekey sets a new PRNG key, which is read from crypto/rand.
func (p *siprng) rekey() {
var k [16]byte
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
panic(err.Error())
}
p.k0 = binary.LittleEndian.Uint64(k[0:8])
p.k1 = binary.LittleEndian.Uint64(k[8:16])
p.ctr = 1
}
// Uint64 returns a new pseudorandom uint64.
// It rekeys PRNG on the first call and every 64 MB of generated data.
func (p *siprng) Uint64() uint64 {
p.mu.Lock()
if p.ctr == 0 || p.ctr > 8*1024*1024 {
p.rekey()
}
v := siphash(p.k0, p.k1, p.ctr)
p.ctr++
p.mu.Unlock()
return v
}
func (p *siprng) Int63() int64 {
return int64(p.Uint64() & 0x7fffffffffffffff)
}
func (p *siprng) Uint32() uint32 {
return uint32(p.Uint64())
}
func (p *siprng) Int31() int32 {
return int32(p.Uint32() & 0x7fffffff)
}
func (p *siprng) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
if n <= 1<<31-1 {
return int(p.Int31n(int32(n)))
}
return int(p.Int63n(int64(n)))
}
func (p *siprng) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
v := p.Int63()
for v > max {
v = p.Int63()
}
return v % n
}
func (p *siprng) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := p.Int31()
for v > max {
v = p.Int31()
}
return v % n
}
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }

View File

@ -0,0 +1,19 @@
package captcha
import "testing"
func TestSiphash(t *testing.T) {
good := uint64(0xe849e8bb6ffe2567)
cur := siphash(0, 0, 0)
if cur != good {
t.Fatalf("siphash: expected %x, got %x", good, cur)
}
}
func BenchmarkSiprng(b *testing.B) {
b.SetBytes(8)
p := &siprng{}
for i := 0; i < b.N; i++ {
p.Uint64()
}
}

299
utils/mail.go Normal file
View File

@ -0,0 +1,299 @@
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/mail"
"net/smtp"
"net/textproto"
"os"
"path"
"path/filepath"
"strconv"
"strings"
)
const (
maxLineLength = 76
)
// Email is the type used for email messages
type Email struct {
Auth smtp.Auth
Identity string `json:"identity"`
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"host"`
Port int `json:"port"`
From string `json:"from"`
To []string
Bcc []string
Cc []string
Subject string
Text string // Plaintext message (optional)
HTML string // Html message (optional)
Headers textproto.MIMEHeader
Attachments []*Attachment
ReadReceipt []string
}
// Attachment is a struct representing an email attachment.
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
type Attachment struct {
Filename string
Header textproto.MIMEHeader
Content []byte
}
func NewEMail(config string) *Email {
e := new(Email)
e.Headers = textproto.MIMEHeader{}
err := json.Unmarshal([]byte(config), e)
if err != nil {
return nil
}
if e.From == "" {
e.From = e.Username
}
return e
}
// make all send information to byte
func (e *Email) Bytes() ([]byte, error) {
buff := &bytes.Buffer{}
w := multipart.NewWriter(buff)
// Set the appropriate headers (overwriting any conflicts)
// Leave out Bcc (only included in envelope headers)
e.Headers.Set("To", strings.Join(e.To, ","))
if e.Cc != nil {
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
}
e.Headers.Set("From", e.From)
e.Headers.Set("Subject", e.Subject)
if len(e.ReadReceipt) != 0 {
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
}
e.Headers.Set("MIME-Version", "1.0")
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
// Write the envelope headers (including any custom headers)
if err := headerToBytes(buff, e.Headers); err != nil {
return nil, fmt.Errorf("Failed to render message headers: %s", err)
}
// Start the multipart/mixed part
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
header := textproto.MIMEHeader{}
// Check to see if there is a Text or HTML field
if e.Text != "" || e.HTML != "" {
subWriter := multipart.NewWriter(buff)
// Create the multipart alternative part
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
// Write the header
if err := headerToBytes(buff, header); err != nil {
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
}
// Create the body sections
if e.Text != "" {
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.Text); err != nil {
return nil, err
}
}
if e.HTML != "" {
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.HTML); err != nil {
return nil, err
}
}
if err := subWriter.Close(); err != nil {
return nil, err
}
}
// Create attachment part, if necessary
for _, a := range e.Attachments {
ap, err := w.CreatePart(a.Header)
if err != nil {
return nil, err
}
// Write the base64Wrapped content to the part
base64Wrap(ap, a.Content)
}
if err := w.Close(); err != nil {
return nil, err
}
return buff.Bytes(), nil
}
// Attach file to the send mail
func (e *Email) AttachFile(filename string) (a *Attachment, err error) {
f, err := os.Open(filename)
if err != nil {
return
}
ct := mime.TypeByExtension(filepath.Ext(filename))
basename := path.Base(filename)
return e.Attach(f, basename, ct)
}
// Attach is used to attach content from an io.Reader to the email.
// Required parameters include an io.Reader, the desired filename for the attachment, and the Content-Type
func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) {
var buffer bytes.Buffer
if _, err = io.Copy(&buffer, r); err != nil {
return
}
at := &Attachment{
Filename: filename,
Header: textproto.MIMEHeader{},
Content: buffer.Bytes(),
}
// Get the Content-Type to be used in the MIMEHeader
if c != "" {
at.Header.Set("Content-Type", c)
} else {
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
at.Header.Set("Content-Type", "application/octet-stream")
}
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
at.Header.Set("Content-Transfer-Encoding", "base64")
e.Attachments = append(e.Attachments, at)
return at, nil
}
func (e *Email) Send() error {
if e.Auth == nil {
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
}
// Merge the To, Cc, and Bcc fields
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
// Check to make sure there is at least one recipient and one "From" address
if e.From == "" || len(to) == 0 {
return errors.New("Must specify at least one From address and one To address")
}
from, err := mail.ParseAddress(e.From)
if err != nil {
return err
}
raw, err := e.Bytes()
if err != nil {
return err
}
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
}
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
func quotePrintEncode(w io.Writer, s string) error {
var buf [3]byte
mc := 0
for i := 0; i < len(s); i++ {
c := s[i]
// We're assuming Unix style text formats as input (LF line break), and
// quoted-printble uses CRLF line breaks. (Literal CRs will become
// "=0D", but probably shouldn't be there to begin with!)
if c == '\n' {
io.WriteString(w, "\r\n")
mc = 0
continue
}
var nextOut []byte
if isPrintable(c) {
nextOut = append(buf[:0], c)
} else {
nextOut = buf[:]
qpEscape(nextOut, c)
}
// Add a soft line break if the next (encoded) byte would push this line
// to or past the limit.
if mc+len(nextOut) >= maxLineLength {
if _, err := io.WriteString(w, "=\r\n"); err != nil {
return err
}
mc = 0
}
if _, err := w.Write(nextOut); err != nil {
return err
}
mc += len(nextOut)
}
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
if mc > 0 {
io.WriteString(w, "=\r\n")
}
return nil
}
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
func isPrintable(c byte) bool {
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
}
// qpEscape is a helper function for quotePrintEncode which escapes a
// non-printable byte. Expects len(dest) == 3.
func qpEscape(dest []byte, c byte) {
const nums = "0123456789ABCDEF"
dest[0] = '='
dest[1] = nums[(c&0xf0)>>4]
dest[2] = nums[(c & 0xf)]
}
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
for k, v := range t {
// Write the header key
_, err := fmt.Fprintf(w, "%s:", k)
if err != nil {
return err
}
// Write each value in the header
for _, c := range v {
_, err := fmt.Fprintf(w, " %s\r\n", c)
if err != nil {
return err
}
}
}
return nil
}
// base64Wrap encodeds the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
// The output is then written to the specified io.Writer
func base64Wrap(w io.Writer, b []byte) {
// 57 raw bytes per 76-byte base64 line.
const maxRaw = 57
// Buffer for each line, including trailing CRLF.
var buffer [maxLineLength + len("\r\n")]byte
copy(buffer[maxLineLength:], "\r\n")
// Process raw chunks until there's no longer enough to fill a line.
for len(b) >= maxRaw {
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
w.Write(buffer[:])
b = b[maxRaw:]
}
// Handle the last chunk of bytes.
if len(b) > 0 {
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
base64.StdEncoding.Encode(out, b)
out = append(out, "\r\n"...)
w.Write(out)
}
}

27
utils/mail_test.go Normal file
View File

@ -0,0 +1,27 @@
package utils
import "testing"
func TestMail(t *testing.T) {
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
mail := NewEMail(config)
if mail.Username != "astaxie@gmail.com" {
t.Fatal("email parse get username error")
}
if mail.Password != "astaxie" {
t.Fatal("email parse get password error")
}
if mail.Host != "smtp.gmail.com" {
t.Fatal("email parse get host error")
}
if mail.Port != 587 {
t.Fatal("email parse get port error")
}
mail.To = []string{"xiemengjun@gmail.com"}
mail.From = "astaxie@gmail.com"
mail.Subject = "hi, just from beego!"
mail.Text = "Text Body is, of course, supported!"
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
mail.Send()
}

20
utils/rand.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"crypto/rand"
)
// RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
if len(alphabets) == 0 {
bytes[i] = alphanum[b%byte(len(alphanum))]
} else {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return bytes
}