mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 07:00:55 +00:00
commit
6b5108ef92
@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
|
||||
beego is licensed under the Apache Licence, Version 2.0
|
||||
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
||||
|
||||
|
||||
## Use case
|
||||
|
||||
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
|
||||
- seocms: [seocms](https://github.com/chinakr/seocms)
|
||||
- CMS: [toropress](https://github.com/insionng/toropress)
|
||||
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
|
||||
[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
|
8
app.go
8
app.go
@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
|
||||
return app
|
||||
}
|
||||
|
||||
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
|
||||
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
|
||||
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
|
||||
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
|
||||
app.Handlers.AddAutoPrefix(prefix, c)
|
||||
return app
|
||||
}
|
||||
|
||||
// UrlFor creates a url with another registered controller handler with params.
|
||||
// The endpoint is formed as path.controller.name to defined the controller method which will run.
|
||||
// The values need key-pair data to assign into controller method.
|
||||
|
126
beego.go
126
beego.go
@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/middleware"
|
||||
@ -13,6 +14,76 @@ import (
|
||||
// beego web framework version.
|
||||
const VERSION = "1.0.1"
|
||||
|
||||
type hookfunc func() error //hook function to run
|
||||
var hooks []hookfunc //hook function slice to store the hookfunc
|
||||
|
||||
type groupRouter struct {
|
||||
pattern string
|
||||
controller ControllerInterface
|
||||
mappingMethods string
|
||||
}
|
||||
|
||||
// RouterGroups which will store routers
|
||||
type GroupRouters []groupRouter
|
||||
|
||||
// Get a new GroupRouters
|
||||
func NewGroupRouters() GroupRouters {
|
||||
return make([]groupRouter, 0)
|
||||
}
|
||||
|
||||
// Add Router in the GroupRouters
|
||||
// it is for plugin or module to register router
|
||||
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
|
||||
var newRG groupRouter
|
||||
if len(mappingMethod) > 0 {
|
||||
newRG = groupRouter{
|
||||
pattern,
|
||||
c,
|
||||
mappingMethod[0],
|
||||
}
|
||||
} else {
|
||||
newRG = groupRouter{
|
||||
pattern,
|
||||
c,
|
||||
"",
|
||||
}
|
||||
}
|
||||
gr = append(gr, newRG)
|
||||
}
|
||||
|
||||
func (gr GroupRouters) AddAuto(c ControllerInterface) {
|
||||
newRG := groupRouter{
|
||||
"",
|
||||
c,
|
||||
"",
|
||||
}
|
||||
gr = append(gr, newRG)
|
||||
}
|
||||
|
||||
// AddGroupRouter with the prefix
|
||||
// it will register the router in BeeApp
|
||||
// the follow code is write in modules:
|
||||
// GR:=NewGroupRouters()
|
||||
// GR.AddRouter("/login",&UserController,"get:Login")
|
||||
// GR.AddRouter("/logout",&UserController,"get:Logout")
|
||||
// GR.AddRouter("/register",&UserController,"get:Reg")
|
||||
// the follow code is write in app:
|
||||
// import "github.com/beego/modules/auth"
|
||||
// AddRouterGroup("/admin", auth.GR)
|
||||
func AddGroupRouter(prefix string, groups GroupRouters) *App {
|
||||
for _, v := range groups {
|
||||
if v.pattern == "" {
|
||||
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
|
||||
} else if v.mappingMethods != "" {
|
||||
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
|
||||
} else {
|
||||
BeeApp.Router(prefix+v.pattern, v.controller)
|
||||
}
|
||||
|
||||
}
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// Router adds a patterned controller handler to BeeApp.
|
||||
// it's an alias method of App.Router.
|
||||
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
|
||||
@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// AutoPrefix adds controller handler to BeeApp with prefix.
|
||||
// it's same to App.AutoRouterWithPrefix.
|
||||
func AutoPrefix(prefix string, c ControllerInterface) *App {
|
||||
BeeApp.AutoRouterWithPrefix(prefix, c)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// ErrorHandler registers http.HandlerFunc to each http err code string.
|
||||
// usage:
|
||||
// beego.ErrorHandler("404",NotFound)
|
||||
@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// The hookfunc will run in beego.Run()
|
||||
// such as sessionInit, middlerware start, buildtemplate, admin start
|
||||
func AddAPPStartHook(hf hookfunc) {
|
||||
hooks = append(hooks, hf)
|
||||
}
|
||||
|
||||
// Run beego application.
|
||||
// it's alias of App.Run.
|
||||
func Run() {
|
||||
@ -99,18 +183,32 @@ func Run() {
|
||||
}
|
||||
}
|
||||
|
||||
//init mime
|
||||
initMime()
|
||||
// do hooks function
|
||||
for _, hk := range hooks {
|
||||
err := hk()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
if SessionOn {
|
||||
GlobalSessions, _ = session.NewManager(SessionProvider,
|
||||
SessionName,
|
||||
SessionGCMaxLifetime,
|
||||
SessionSavePath,
|
||||
HttpTLS,
|
||||
SessionHashFunc,
|
||||
SessionHashKey,
|
||||
SessionCookieLifeTime)
|
||||
var err error
|
||||
sessionConfig := AppConfig.String("sessionConfig")
|
||||
if sessionConfig == "" {
|
||||
sessionConfig = `{"cookieName":"` + SessionName + `",` +
|
||||
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
|
||||
`"providerConfig":"` + SessionSavePath + `",` +
|
||||
`"secure":` + strconv.FormatBool(HttpTLS) + `,` +
|
||||
`"sessionIDHashFunc":"` + SessionHashFunc + `",` +
|
||||
`"sessionIDHashKey":"` + SessionHashKey + `",` +
|
||||
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
|
||||
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
|
||||
}
|
||||
GlobalSessions, err = session.NewManager(SessionProvider,
|
||||
sessionConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go GlobalSessions.GC()
|
||||
}
|
||||
|
||||
@ -123,7 +221,7 @@ func Run() {
|
||||
|
||||
middleware.VERSION = VERSION
|
||||
middleware.AppName = AppName
|
||||
middleware.RegisterErrorHander()
|
||||
middleware.RegisterErrorHandler()
|
||||
|
||||
if EnableAdmin {
|
||||
go BeeAdminApp.Run()
|
||||
@ -131,3 +229,9 @@ func Run() {
|
||||
|
||||
BeeApp.Run()
|
||||
}
|
||||
|
||||
func init() {
|
||||
hooks = make([]hookfunc, 0)
|
||||
//init mime
|
||||
AddAPPStartHook(initMime)
|
||||
}
|
||||
|
2
cache/README.md
vendored
2
cache/README.md
vendored
@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
|
||||
|
||||
## Memcache adapter
|
||||
|
||||
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
||||
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
||||
|
||||
Configure like this:
|
||||
|
||||
|
50
cache/cache_test.go
vendored
50
cache/cache_test.go
vendored
@ -5,7 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_cache(t *testing.T) {
|
||||
func TestCache(t *testing.T) {
|
||||
bm, err := NewCache("memory", `{"interval":20}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
|
||||
t.Error("delete err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCache(t *testing.T) {
|
||||
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
if err = bm.Put("astaxie", 1, 10); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
//test string
|
||||
if err = bm.Put("astaxie", "author", 10); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(string) != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
}
|
||||
|
10
cache/file.go
vendored
10
cache/file.go
vendored
@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
|
||||
var cfg map[string]string
|
||||
json.Unmarshal([]byte(config), &cfg)
|
||||
//fmt.Println(cfg)
|
||||
//fmt.Println(config)
|
||||
if _, ok := cfg["CachePath"]; !ok {
|
||||
cfg["CachePath"] = FileCachePath
|
||||
}
|
||||
@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
|
||||
return ""
|
||||
}
|
||||
var to FileCacheItem
|
||||
Gob_decode([]byte(filedata), &to)
|
||||
Gob_decode(filedata, &to)
|
||||
if to.Expired < time.Now().Unix() {
|
||||
return ""
|
||||
}
|
||||
@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
|
||||
func (this *FileCache) Incr(key string) error {
|
||||
data := this.Get(key)
|
||||
var incr int
|
||||
fmt.Println(reflect.TypeOf(data).Name())
|
||||
//fmt.Println(reflect.TypeOf(data).Name())
|
||||
if reflect.TypeOf(data).Name() != "int" {
|
||||
incr = 0
|
||||
} else {
|
||||
@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
|
||||
// Clean cached files.
|
||||
// not implemented.
|
||||
func (this *FileCache) ClearAll() error {
|
||||
//this.CachePath .递归删除
|
||||
|
||||
//this.CachePath
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
|
||||
}
|
||||
|
||||
// Gob decodes file cache item.
|
||||
func Gob_decode(data []byte, to interface{}) error {
|
||||
func Gob_decode(data []byte, to *FileCacheItem) error {
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
return dec.Decode(&to)
|
||||
|
41
cache/memcache.go
vendored
41
cache/memcache.go
vendored
@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
|
||||
// get value from memcache.
|
||||
func (rc *MemcacheCache) Get(key string) interface{} {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Get(key)
|
||||
if err != nil {
|
||||
@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
|
||||
// put value to memcache. only support string.
|
||||
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
||||
// delete value in memcache.
|
||||
func (rc *MemcacheCache) Delete(key string) error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Delete(key)
|
||||
return err
|
||||
@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
|
||||
// check value exists in memcache.
|
||||
func (rc *MemcacheCache) IsExist(key string) bool {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Get(key)
|
||||
if err != nil {
|
||||
@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
|
||||
// clear all cached in memcache.
|
||||
func (rc *MemcacheCache) ClearAll() error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := rc.c.FlushAll()
|
||||
return err
|
||||
@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
rc.conninfo = cf["conn"]
|
||||
rc.c = rc.connectInit()
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return errors.New("dial tcp conn error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to memcache and keep the connection.
|
||||
func (rc *MemcacheCache) connectInit() *memcache.Connection {
|
||||
func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
|
||||
c, err := memcache.Connect(rc.conninfo)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return c
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
118
cache/redis.go
vendored
118
cache/redis.go
vendored
@ -3,6 +3,7 @@ package cache
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/beego/redigo/redis"
|
||||
)
|
||||
@ -14,7 +15,7 @@ var (
|
||||
|
||||
// Redis cache adapter.
|
||||
type RedisCache struct {
|
||||
c redis.Conn
|
||||
p *redis.Pool // redis connection pool
|
||||
conninfo string
|
||||
key string
|
||||
}
|
||||
@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
|
||||
return &RedisCache{key: DefaultKey}
|
||||
}
|
||||
|
||||
// actually do the redis cmds
|
||||
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
|
||||
return c.Do(commandName, args...)
|
||||
}
|
||||
|
||||
// Get cache from redis.
|
||||
func (rc *RedisCache) Get(key string) interface{} {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Do("HGET", rc.key, key)
|
||||
v, err := rc.do("HGET", rc.key, key)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// put cache to redis.
|
||||
// timeout is ignored.
|
||||
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("HSET", rc.key, key, val)
|
||||
_, err := rc.do("HSET", rc.key, key, val)
|
||||
return err
|
||||
}
|
||||
|
||||
// delete cache in redis.
|
||||
func (rc *RedisCache) Delete(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("HDEL", rc.key, key)
|
||||
_, err := rc.do("HDEL", rc.key, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// check cache exist in redis.
|
||||
func (rc *RedisCache) IsExist(key string) bool {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
|
||||
v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// increase counter in redis.
|
||||
func (rc *RedisCache) Incr(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
|
||||
return err
|
||||
}
|
||||
|
||||
// decrease counter in redis.
|
||||
func (rc *RedisCache) Decr(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
|
||||
return err
|
||||
}
|
||||
|
||||
// clean all cache in redis. delete this redis collection.
|
||||
func (rc *RedisCache) ClearAll() error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("DEL", rc.key)
|
||||
_, err := rc.do("DEL", rc.key)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
|
||||
func (rc *RedisCache) StartAndGC(config string) error {
|
||||
var cf map[string]string
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
|
||||
if _, ok := cf["key"]; !ok {
|
||||
cf["key"] = DefaultKey
|
||||
}
|
||||
|
||||
if _, ok := cf["conn"]; !ok {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
|
||||
rc.key = cf["key"]
|
||||
rc.conninfo = cf["conn"]
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
rc.connectInit()
|
||||
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
if err := c.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if rc.c == nil {
|
||||
return errors.New("dial tcp conn error")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to redis.
|
||||
func (rc *RedisCache) connectInit() (redis.Conn, error) {
|
||||
c, err := redis.Dial("tcp", rc.conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (rc *RedisCache) connectInit() {
|
||||
// initialize a new pool
|
||||
rc.p = &redis.Pool{
|
||||
MaxIdle: 3,
|
||||
IdleTimeout: 180 * time.Second,
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.Dial("tcp", rc.conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -40,6 +40,7 @@ var (
|
||||
SessionHashFunc string // session hash generation func.
|
||||
SessionHashKey string // session hash salt string.
|
||||
SessionCookieLifeTime int // the life time of session id in cookie.
|
||||
SessionAutoSetCookie bool // auto setcookie
|
||||
UseFcgi bool
|
||||
MaxMemory int64
|
||||
EnableGzip bool // flag of enable gzip
|
||||
@ -96,6 +97,7 @@ func init() {
|
||||
SessionHashFunc = "sha1"
|
||||
SessionHashKey = "beegoserversessionkey"
|
||||
SessionCookieLifeTime = 0 //set cookie default is the brower life
|
||||
SessionAutoSetCookie = true
|
||||
|
||||
UseFcgi = false
|
||||
|
||||
@ -139,6 +141,7 @@ func init() {
|
||||
func ParseConfig() (err error) {
|
||||
AppConfig, err = config.NewConfig("ini", AppConfigPath)
|
||||
if err != nil {
|
||||
AppConfig = config.NewFakeConfig()
|
||||
return err
|
||||
} else {
|
||||
HttpAddr = AppConfig.String("HttpAddr")
|
||||
|
@ -6,8 +6,9 @@ import (
|
||||
|
||||
// ConfigContainer defines how to get and set value from configuration raw data.
|
||||
type ConfigContainer interface {
|
||||
Set(key, val string) error // support section::key type in given key when using ini type.
|
||||
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
|
||||
Set(key, val string) error // support section::key type in given key when using ini type.
|
||||
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
|
||||
Strings(key string) []string //get string slice
|
||||
Int(key string) (int, error)
|
||||
Int64(key string) (int64, error)
|
||||
Bool(key string) (bool, error)
|
||||
|
62
config/fake.go
Normal file
62
config/fake.go
Normal file
@ -0,0 +1,62 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type fakeConfigContainer struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) getData(key string) string {
|
||||
key = strings.ToLower(key)
|
||||
return c.data[key]
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Set(key, val string) error {
|
||||
key = strings.ToLower(key)
|
||||
c.data[key] = val
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) String(key string) string {
|
||||
return c.getData(key)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.getData(key), ";")
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Int(key string) (int, error) {
|
||||
return strconv.Atoi(c.getData(key))
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
|
||||
return strconv.ParseInt(c.getData(key), 10, 64)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
|
||||
return strconv.ParseBool(c.getData(key))
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Float(key string) (float64, error) {
|
||||
return strconv.ParseFloat(c.getData(key), 64)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
|
||||
key = strings.ToLower(key)
|
||||
if v, ok := c.data[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return nil, errors.New("key not find")
|
||||
}
|
||||
|
||||
var _ ConfigContainer = new(fakeConfigContainer)
|
||||
|
||||
func NewFakeConfig() ConfigContainer {
|
||||
return &fakeConfigContainer{
|
||||
data: make(map[string]string),
|
||||
}
|
||||
}
|
@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
|
||||
return c.getdata(key)
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *IniConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
// if write to one section, the key need be "section::key".
|
||||
// if the section is not existed, it panics.
|
||||
|
@ -19,6 +19,7 @@ copyrequestbody = true
|
||||
key1="asta"
|
||||
key2 = "xie"
|
||||
CaseInsensitive = true
|
||||
peers = one;two;three
|
||||
`
|
||||
|
||||
func TestIni(t *testing.T) {
|
||||
@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
|
||||
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
|
||||
t.Fatal("get demo.caseinsensitive error")
|
||||
}
|
||||
|
||||
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
|
||||
t.Fatal("get strings error", data)
|
||||
} else if data[0] != "one" {
|
||||
t.Fatal("get first params error not equat to one")
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *JsonConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *JsonConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/beego/x2j"
|
||||
@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *XMLConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *XMLConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/beego/goyaml2"
|
||||
@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *YAMLConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *YAMLConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
|
@ -3,7 +3,6 @@ package beego
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
@ -22,6 +21,7 @@ import (
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
"github.com/astaxie/beego/session"
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -140,7 +140,7 @@ func (c *Controller) RenderString() (string, error) {
|
||||
return string(b), e
|
||||
}
|
||||
|
||||
// RenderBytes returns the bytes of renderd tempate string. Do not send out response.
|
||||
// RenderBytes returns the bytes of rendered template string. Do not send out response.
|
||||
func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
//if the controller has set layout, then first get the tplname's content set the content to the layout
|
||||
if c.Layout != "" {
|
||||
@ -165,7 +165,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
|
||||
if c.LayoutSections != nil {
|
||||
for sectionName, sectionTpl := range c.LayoutSections {
|
||||
if (sectionTpl == "") {
|
||||
if sectionTpl == "" {
|
||||
c.Data[sectionName] = ""
|
||||
continue
|
||||
}
|
||||
@ -391,6 +391,7 @@ func (c *Controller) DelSession(name interface{}) {
|
||||
// SessionRegenerateID regenerates session id for this session.
|
||||
// the session data have no changes.
|
||||
func (c *Controller) SessionRegenerateID() {
|
||||
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
|
||||
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
|
||||
c.Ctx.Input.CruSession = c.CruSession
|
||||
}
|
||||
@ -454,7 +455,7 @@ func (c *Controller) XsrfToken() string {
|
||||
} else {
|
||||
expire = int64(XSRFExpire)
|
||||
}
|
||||
token = getRandomString(15)
|
||||
token = string(utils.RandomCreateBytes(15))
|
||||
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
|
||||
}
|
||||
c._xsrf_token = token
|
||||
@ -491,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string {
|
||||
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
|
||||
return c.controllerName, c.actionName
|
||||
}
|
||||
|
||||
// getRandomString returns random string.
|
||||
func getRandomString(n int) string {
|
||||
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
var bytes = make([]byte, n)
|
||||
rand.Read(bytes)
|
||||
for i, b := range bytes {
|
||||
bytes[i] = alphanum[b%byte(len(alphanum))]
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
@ -1,12 +1,13 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/garyburd/go-websocket/websocket"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
|
17
filter.go
17
filter.go
@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
||||
if router == mr.pattern {
|
||||
return true, nil
|
||||
}
|
||||
//pattern /admin router /admin/ match
|
||||
//pattern /admin/ router /admin don't match, because url will 301 in router
|
||||
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if mr.hasregex {
|
||||
if !mr.regex.MatchString(router) {
|
||||
return false, nil
|
||||
@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
|
||||
mr := new(FilterRouter)
|
||||
mr.params = make(map[int]string)
|
||||
mr.filterFunc = filter
|
||||
@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
j := 0
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
//a user may choose to override the default expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
j++
|
||||
}
|
||||
if strings.HasPrefix(part, "*") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
if part == "*.*" {
|
||||
mr.params[j] = ":path"
|
||||
parts[i] = "([^.]+).([^.]+)"
|
||||
@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
//TODO add error handling here to avoid panic
|
||||
panic(regexErr)
|
||||
return nil, regexErr
|
||||
}
|
||||
mr.regex = regex
|
||||
mr.hasregex = true
|
||||
}
|
||||
mr.pattern = pattern
|
||||
return mr
|
||||
return mr, nil
|
||||
}
|
||||
|
@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
|
||||
t.Errorf("user define func can't run")
|
||||
}
|
||||
}
|
||||
|
||||
var FilterAdminUser = func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("i am admin"))
|
||||
}
|
||||
|
||||
// Filter pattern /admin/:all
|
||||
// all url like /admin/ /admin/xie will all get filter
|
||||
|
||||
func TestPatternTwo(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am admin" {
|
||||
t.Errorf("filter /admin/ can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternThree(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am admin" {
|
||||
t.Errorf("filter /admin/astaxie can't run")
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ConnWriter implements LoggerInterface.
|
||||
// it writes messages in keep-live tcp connection.
|
||||
type ConnWriter struct {
|
||||
lg *log.Logger
|
||||
innerWriter io.WriteCloser
|
||||
@ -17,12 +19,15 @@ type ConnWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create new ConnWrite returning as LoggerInterface.
|
||||
func NewConn() LoggerInterface {
|
||||
conn := new(ConnWriter)
|
||||
conn.Level = LevelTrace
|
||||
return conn
|
||||
}
|
||||
|
||||
// init connection writer with json config.
|
||||
// json config only need key "level".
|
||||
func (c *ConnWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||
if err != nil {
|
||||
@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in connection.
|
||||
// if connection is down, try to re-connect.
|
||||
func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
||||
if level < c.Level {
|
||||
return nil
|
||||
@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConnWriter) Flush() {
|
||||
|
||||
}
|
||||
|
||||
// destroy connection writer and close tcp listener.
|
||||
func (c *ConnWriter) Destroy() {
|
||||
if c.innerWriter == nil {
|
||||
return
|
||||
|
@ -6,11 +6,13 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
|
||||
type ConsoleWriter struct {
|
||||
lg *log.Logger
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create ConsoleWriter returning as LoggerInterface.
|
||||
func NewConsole() LoggerInterface {
|
||||
cw := new(ConsoleWriter)
|
||||
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
|
||||
@ -18,6 +20,8 @@ func NewConsole() LoggerInterface {
|
||||
return cw
|
||||
}
|
||||
|
||||
// init console logger.
|
||||
// jsonconfig like '{"level":LevelTrace}'.
|
||||
func (c *ConsoleWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||
if err != nil {
|
||||
@ -26,6 +30,7 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in console.
|
||||
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
|
||||
if level < c.Level {
|
||||
return nil
|
||||
@ -34,10 +39,12 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConsoleWriter) Destroy() {
|
||||
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConsoleWriter) Flush() {
|
||||
|
||||
}
|
||||
|
21
logs/file.go
21
logs/file.go
@ -13,6 +13,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileLogWriter implements LoggerInterface.
|
||||
// It writes messages by lines limit, file size limit, or time frequency.
|
||||
type FileLogWriter struct {
|
||||
*log.Logger
|
||||
mw *MuxWriter
|
||||
@ -38,17 +40,20 @@ type FileLogWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// an *os.File writer with locker.
|
||||
type MuxWriter struct {
|
||||
sync.Mutex
|
||||
fd *os.File
|
||||
}
|
||||
|
||||
// write to os.File.
|
||||
func (l *MuxWriter) Write(b []byte) (int, error) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.fd.Write(b)
|
||||
}
|
||||
|
||||
// set os.File in writer.
|
||||
func (l *MuxWriter) SetFd(fd *os.File) {
|
||||
if l.fd != nil {
|
||||
l.fd.Close()
|
||||
@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
|
||||
l.fd = fd
|
||||
}
|
||||
|
||||
// create a FileLogWriter returning as LoggerInterface.
|
||||
func NewFileWriter() LoggerInterface {
|
||||
w := &FileLogWriter{
|
||||
Filename: "",
|
||||
@ -73,15 +79,16 @@ func NewFileWriter() LoggerInterface {
|
||||
return w
|
||||
}
|
||||
|
||||
// jsonconfig like this
|
||||
//{
|
||||
// Init file logger with json config.
|
||||
// jsonconfig like:
|
||||
// {
|
||||
// "filename":"logs/beego.log",
|
||||
// "maxlines":10000,
|
||||
// "maxsize":1<<30,
|
||||
// "daily":true,
|
||||
// "maxdays":15,
|
||||
// "rotate":true
|
||||
//}
|
||||
// }
|
||||
func (w *FileLogWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), w)
|
||||
if err != nil {
|
||||
@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// start file logger. create log file and set to locker-inside file writer.
|
||||
func (w *FileLogWriter) StartLogger() error {
|
||||
fd, err := w.createLogFile()
|
||||
if err != nil {
|
||||
@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
|
||||
w.maxsize_cursize += size
|
||||
}
|
||||
|
||||
// write logger message into file.
|
||||
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
|
||||
if level < w.Level {
|
||||
return nil
|
||||
@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoRotate means it need to write file in new file.
|
||||
// new file name like xx.log.2013-01-01.2
|
||||
func (w *FileLogWriter) DoRotate() error {
|
||||
_, err := os.Lstat(w.Filename)
|
||||
if err == nil { // file exists
|
||||
@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
|
||||
})
|
||||
}
|
||||
|
||||
// destroy file logger, close file writer.
|
||||
func (w *FileLogWriter) Destroy() {
|
||||
w.mw.fd.Close()
|
||||
}
|
||||
|
||||
// flush file logger.
|
||||
// there are no buffering messages in file logger in memory.
|
||||
// flush file means sync file from disk.
|
||||
func (w *FileLogWriter) Flush() {
|
||||
w.mw.fd.Sync()
|
||||
}
|
||||
|
24
logs/log.go
24
logs/log.go
@ -6,6 +6,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// log message levels
|
||||
LevelTrace = iota
|
||||
LevelDebug
|
||||
LevelInfo
|
||||
@ -16,6 +17,7 @@ const (
|
||||
|
||||
type loggerType func() LoggerInterface
|
||||
|
||||
// LoggerInterface defines the behavior of a log provider.
|
||||
type LoggerInterface interface {
|
||||
Init(config string) error
|
||||
WriteMsg(msg string, level int) error
|
||||
@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
|
||||
adapters[name] = log
|
||||
}
|
||||
|
||||
// BeeLogger is default logger in beego application.
|
||||
// it can contain several providers and log message into all providers.
|
||||
type BeeLogger struct {
|
||||
lock sync.Mutex
|
||||
level int
|
||||
@ -50,7 +54,9 @@ type logMsg struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
// config need to be correct JSON as string: {"interval":360}
|
||||
// NewLogger returns a new BeeLogger.
|
||||
// channellen means the number of messages in chan.
|
||||
// if the buffering chan is full, logger adapters write to file or other way.
|
||||
func NewLogger(channellen int64) *BeeLogger {
|
||||
bl := new(BeeLogger)
|
||||
bl.msg = make(chan *logMsg, channellen)
|
||||
@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
|
||||
return bl
|
||||
}
|
||||
|
||||
// SetLogger provides a given logger adapter into BeeLogger with config string.
|
||||
// config need to be correct JSON as string: {"interval":360}.
|
||||
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
||||
bl.lock.Lock()
|
||||
defer bl.lock.Unlock()
|
||||
@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// remove a logger adapter in BeeLogger.
|
||||
func (bl *BeeLogger) DelLogger(adaptername string) error {
|
||||
bl.lock.Lock()
|
||||
defer bl.lock.Unlock()
|
||||
@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// set log message level.
|
||||
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
|
||||
func (bl *BeeLogger) SetLevel(l int) {
|
||||
bl.level = l
|
||||
}
|
||||
|
||||
// start logger chan reading.
|
||||
// when chan is full, write logs.
|
||||
func (bl *BeeLogger) StartLogger() {
|
||||
for {
|
||||
select {
|
||||
@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
|
||||
}
|
||||
}
|
||||
|
||||
// log trace level message.
|
||||
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[T] "+format, v...)
|
||||
bl.writerMsg(LevelTrace, msg)
|
||||
}
|
||||
|
||||
// log debug level message.
|
||||
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[D] "+format, v...)
|
||||
bl.writerMsg(LevelDebug, msg)
|
||||
}
|
||||
|
||||
// log info level message.
|
||||
func (bl *BeeLogger) Info(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[I] "+format, v...)
|
||||
bl.writerMsg(LevelInfo, msg)
|
||||
}
|
||||
|
||||
// log warn level message.
|
||||
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[W] "+format, v...)
|
||||
bl.writerMsg(LevelWarn, msg)
|
||||
}
|
||||
|
||||
// log error level message.
|
||||
func (bl *BeeLogger) Error(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[E] "+format, v...)
|
||||
bl.writerMsg(LevelError, msg)
|
||||
}
|
||||
|
||||
// log critical level message.
|
||||
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[C] "+format, v...)
|
||||
bl.writerMsg(LevelCritical, msg)
|
||||
}
|
||||
|
||||
//flush all chan data
|
||||
// flush all chan data.
|
||||
func (bl *BeeLogger) Flush() {
|
||||
for _, l := range bl.outputs {
|
||||
l.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// close logger, flush all chan data and destroy all adapters in BeeLogger.
|
||||
func (bl *BeeLogger) Close() {
|
||||
for {
|
||||
if len(bl.msg) > 0 {
|
||||
|
18
logs/smtp.go
18
logs/smtp.go
@ -12,7 +12,7 @@ const (
|
||||
subjectPhrase = "Diagnostic message from server"
|
||||
)
|
||||
|
||||
// smtpWriter is used to send emails via given SMTP-server.
|
||||
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
|
||||
type SmtpWriter struct {
|
||||
Username string `json:"Username"`
|
||||
Password string `json:"password"`
|
||||
@ -22,10 +22,21 @@ type SmtpWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create smtp writer.
|
||||
func NewSmtpWriter() LoggerInterface {
|
||||
return &SmtpWriter{Level: LevelTrace}
|
||||
}
|
||||
|
||||
// init smtp writer with json config.
|
||||
// config like:
|
||||
// {
|
||||
// "Username":"example@gmail.com",
|
||||
// "password:"password",
|
||||
// "host":"smtp.gmail.com:465",
|
||||
// "subject":"email title",
|
||||
// "sendTos":["email1","email2"],
|
||||
// "level":LevelError
|
||||
// }
|
||||
func (s *SmtpWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
||||
if err != nil {
|
||||
@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in smtp writer.
|
||||
// it will send an email with subject and only this message.
|
||||
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
||||
if level < s.Level {
|
||||
return nil
|
||||
@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (s *SmtpWriter) Flush() {
|
||||
return
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (s *SmtpWriter) Destroy() {
|
||||
return
|
||||
}
|
||||
|
@ -5,16 +5,17 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
//"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
|
||||
var lock sync.RWMutex
|
||||
|
||||
// OpenMemZipFile returns MemFile object with a compressed static file.
|
||||
// it's used for serve static file if gzip enable.
|
||||
@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
|
||||
modtime := osfileinfo.ModTime()
|
||||
fileSize := osfileinfo.Size()
|
||||
|
||||
lock.RLock()
|
||||
cfi, ok := gmfim[zip+":"+path]
|
||||
lock.RUnlock()
|
||||
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
|
||||
//fmt.Printf("read %s file %s from cache\n", zip, path)
|
||||
|
||||
} else {
|
||||
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
|
||||
var content []byte
|
||||
if zip == "gzip" {
|
||||
//将文件内容压缩到zipbuf中
|
||||
@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
}
|
||||
|
||||
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
gmfim[zip+":"+path] = cfi
|
||||
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
|
||||
}
|
||||
return &MemFile{fi: cfi, offset: 0}, nil
|
||||
}
|
||||
|
@ -61,6 +61,7 @@ var tpl = `
|
||||
</html>
|
||||
`
|
||||
|
||||
// render default application error page with error and stack string.
|
||||
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(tpl)
|
||||
data := make(map[string]string)
|
||||
@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
|
||||
data["Stack"] = Stack
|
||||
data["BeegoVersion"] = VERSION
|
||||
data["GoVersion"] = runtime.Version()
|
||||
rw.WriteHeader(500)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
@ -174,18 +176,19 @@ var errtpl = `
|
||||
</html>
|
||||
`
|
||||
|
||||
// map of http handlers for each error string.
|
||||
var ErrorMaps map[string]http.HandlerFunc
|
||||
|
||||
func init() {
|
||||
ErrorMaps = make(map[string]http.HandlerFunc)
|
||||
}
|
||||
|
||||
//404
|
||||
// show 404 notfound error.
|
||||
func NotFound(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Page Not Found"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>The page has moved" +
|
||||
@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//401
|
||||
// show 401 unauthorized error.
|
||||
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Unauthorized"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>Check the credentials that you supplied" +
|
||||
"<br>Check the address for errors" +
|
||||
"<br>The credentials you supplied are incorrect" +
|
||||
"<br>There are errors in the website address" +
|
||||
"</ul>")
|
||||
data["BeegoVersion"] = VERSION
|
||||
//rw.WriteHeader(http.StatusUnauthorized)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//403
|
||||
// show 403 forbidden error.
|
||||
func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Forbidden"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested forbidden." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>Your address may be blocked" +
|
||||
@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//503
|
||||
// show 503 service unavailable error.
|
||||
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Service Unavailable"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested unavailable." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br><br>The page is overloaded" +
|
||||
@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//500
|
||||
// show 500 internal server error.
|
||||
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Internal Server Error"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested has down now." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
|
||||
"<br><br><ul>" +
|
||||
"<br>simply try again later" +
|
||||
"<br>you should report the fault to the website administrator" +
|
||||
"</ul>")
|
||||
"<br>Please try again later and report the error to the website administrator" +
|
||||
"<br></ul>")
|
||||
data["BeegoVersion"] = VERSION
|
||||
//rw.WriteHeader(http.StatusInternalServerError)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
// show 500 internal error with simple text string.
|
||||
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// add http handler for given error string.
|
||||
func Errorhandler(err string, h http.HandlerFunc) {
|
||||
ErrorMaps[err] = h
|
||||
}
|
||||
|
||||
func RegisterErrorHander() {
|
||||
// register default error http handlers, 404,401,403,500 and 503.
|
||||
func RegisterErrorHandler() {
|
||||
if _, ok := ErrorMaps["404"]; !ok {
|
||||
ErrorMaps["404"] = NotFound
|
||||
}
|
||||
@ -292,6 +297,8 @@ func RegisterErrorHander() {
|
||||
}
|
||||
}
|
||||
|
||||
// show error string as simple text message.
|
||||
// if error string is empty, show 500 error as default.
|
||||
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
|
||||
if h, ok := ErrorMaps[errcode]; ok {
|
||||
isint, err := strconv.Atoi(errcode)
|
||||
|
@ -2,16 +2,19 @@ package middleware
|
||||
|
||||
import "fmt"
|
||||
|
||||
// http exceptions
|
||||
type HTTPException struct {
|
||||
StatusCode int // http status code 4xx, 5xx
|
||||
Description string
|
||||
}
|
||||
|
||||
// return http exception error string, e.g. "400 Bad Request".
|
||||
func (e *HTTPException) Error() string {
|
||||
// return `status description`, e.g. `400 Bad Request`
|
||||
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
|
||||
}
|
||||
|
||||
// map of http exceptions for each http status code int.
|
||||
// defined 400,401,403,404,405,500,502,503 and 504 default.
|
||||
var HTTPExceptionMaps map[int]HTTPException
|
||||
|
||||
func init() {
|
||||
|
3
mime.go
3
mime.go
@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
|
||||
".mustache": "text/html",
|
||||
}
|
||||
|
||||
func initMime() {
|
||||
func initMime() error {
|
||||
for k, v := range mimemaps {
|
||||
mime.AddExtensionType(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
12
orm/cmd.go
12
orm/cmd.go
@ -16,6 +16,7 @@ var (
|
||||
commands = make(map[string]commander)
|
||||
)
|
||||
|
||||
// print help.
|
||||
func printHelp(errs ...string) {
|
||||
content := `orm command usage:
|
||||
|
||||
@ -31,6 +32,7 @@ func printHelp(errs ...string) {
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
// listen for orm command and then run it if command arguments passed.
|
||||
func RunCommand() {
|
||||
if len(os.Args) < 2 || os.Args[1] != "orm" {
|
||||
return
|
||||
@ -58,6 +60,7 @@ func RunCommand() {
|
||||
}
|
||||
}
|
||||
|
||||
// sync database struct command interface.
|
||||
type commandSyncDb struct {
|
||||
al *alias
|
||||
force bool
|
||||
@ -66,6 +69,7 @@ type commandSyncDb struct {
|
||||
rtOnError bool
|
||||
}
|
||||
|
||||
// parse orm command line arguments.
|
||||
func (d *commandSyncDb) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// run orm line command.
|
||||
func (d *commandSyncDb) Run() error {
|
||||
var drops []string
|
||||
if d.force {
|
||||
@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// database creation commander interface implement.
|
||||
type commandSqlAll struct {
|
||||
al *alias
|
||||
}
|
||||
|
||||
// parse orm command line arguments.
|
||||
func (d *commandSqlAll) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// run orm line command.
|
||||
func (d *commandSqlAll) Run() error {
|
||||
sqls, indexes := getDbCreateSql(d.al)
|
||||
var all []string
|
||||
@ -243,6 +251,10 @@ func init() {
|
||||
commands["sqlall"] = new(commandSqlAll)
|
||||
}
|
||||
|
||||
// run syncdb command line.
|
||||
// name means table's alias name. default is "default".
|
||||
// force means run next sql if the current is error.
|
||||
// verbose means show all info when running command or not.
|
||||
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||
BootStrap()
|
||||
|
||||
|
@ -12,6 +12,7 @@ type dbIndex struct {
|
||||
Sql string
|
||||
}
|
||||
|
||||
// create database drop sql.
|
||||
func getDbDropSql(al *alias) (sqls []string) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
|
||||
return sqls
|
||||
}
|
||||
|
||||
// get database column type string.
|
||||
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||
T := al.DbBaser.DbTypes()
|
||||
fieldType := fi.fieldType
|
||||
@ -79,6 +81,7 @@ checkColumn:
|
||||
return
|
||||
}
|
||||
|
||||
// create alter sql string.
|
||||
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
Q := al.DbBaser.TableQuote()
|
||||
typ := getColumnTyp(al, fi)
|
||||
@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
||||
}
|
||||
|
||||
// create database creation string.
|
||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
|
173
orm/db.go
173
orm/db.go
@ -15,7 +15,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissPK = errors.New("missed pk value")
|
||||
ErrMissPK = errors.New("missed pk value") // missing pk error
|
||||
)
|
||||
|
||||
var (
|
||||
@ -45,13 +45,22 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// an instance of dbBaser interface/
|
||||
type dbBase struct {
|
||||
ins dbBaser
|
||||
}
|
||||
|
||||
// check dbBase implements dbBaser interface.
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
// get struct columns values as interface slice.
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
||||
var columns []string
|
||||
|
||||
if names != nil {
|
||||
columns = *names
|
||||
}
|
||||
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||
@ -64,14 +73,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
||||
}
|
||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
columns = append(columns, column)
|
||||
|
||||
if names != nil {
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
*names = columns
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// get one field value in struct column as interface.
|
||||
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
|
||||
var value interface{}
|
||||
if fi.pk {
|
||||
@ -140,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// create insert sql preparation statement object.
|
||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
@ -165,8 +185,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
return stmt, query, err
|
||||
}
|
||||
|
||||
// insert struct with prepared statement and given struct reflect value.
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -185,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
||||
}
|
||||
}
|
||||
|
||||
// query sql ,read records and persist in dbBaser.
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||
var whereCols []string
|
||||
var args []interface{}
|
||||
@ -192,7 +214,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
// if specify cols length > 0, then use it for where condition.
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||
whereCols = make([]string, 0, len(cols))
|
||||
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -202,7 +225,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
whereCols = append(whereCols, pkColumn)
|
||||
whereCols = []string{pkColumn}
|
||||
args = append(args, pkValue)
|
||||
}
|
||||
|
||||
@ -243,16 +266,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return d.InsertValue(q, mi, names, values)
|
||||
return d.InsertValue(q, mi, false, names, values)
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
|
||||
// multi-insert sql with given slice struct reflect.Value.
|
||||
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||
var (
|
||||
cnt int64
|
||||
nums int
|
||||
values []interface{}
|
||||
names []string
|
||||
)
|
||||
|
||||
// typ := reflect.Indirect(mi.addrField).Type()
|
||||
|
||||
length := sind.Len()
|
||||
|
||||
for i := 1; i <= length; i++ {
|
||||
|
||||
ind := reflect.Indirect(sind.Index(i - 1))
|
||||
|
||||
// Is this needed ?
|
||||
// if !ind.Type().AssignableTo(typ) {
|
||||
// return cnt, ErrArgs
|
||||
// }
|
||||
|
||||
if i == 1 {
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
values = make([]interface{}, bulk*len(vus))
|
||||
nums += copy(values, vus)
|
||||
|
||||
} else {
|
||||
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
if len(vus) != len(names) {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
nums += copy(values[nums:], vus)
|
||||
}
|
||||
|
||||
if i > 1 && i%bulk == 0 || length == i {
|
||||
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
cnt += num
|
||||
nums = 0
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// execute insert sql with given struct and given values.
|
||||
// insert the given values, not the field values in struct.
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
marks := make([]string, len(names))
|
||||
@ -264,36 +348,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
multi := len(values) / len(names)
|
||||
|
||||
if isMulti {
|
||||
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if d.ins.HasReturningID(mi, &query) {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
} else {
|
||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||
if res, err := q.Exec(query, values...); err == nil {
|
||||
if isMulti {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
return res.LastInsertId()
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
}
|
||||
|
||||
// execute update sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
|
||||
var setNames []string
|
||||
|
||||
// if specify cols length is zero, then commit all columns.
|
||||
if len(cols) == 0 {
|
||||
cols = mi.fields.dbcols
|
||||
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
} else {
|
||||
setNames = make([]string, 0, len(cols))
|
||||
}
|
||||
|
||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -317,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// execute delete sql dbQuerier with given struct reflect.Value.
|
||||
// delete index is pk.
|
||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
@ -358,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// update table-related record by querySet.
|
||||
// need querySet not struct reflect.Value to update related records.
|
||||
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
||||
columns := make([]string, 0, len(params))
|
||||
values := make([]interface{}, 0, len(params))
|
||||
@ -433,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// delete related records.
|
||||
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
||||
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
||||
for _, fi := range mi.fields.fieldsReverse {
|
||||
fi = fi.reverseFieldInfo
|
||||
@ -459,8 +564,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete table-related records.
|
||||
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.skipEnd = true
|
||||
|
||||
if qs != nil {
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
}
|
||||
@ -486,6 +594,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var ref interface{}
|
||||
|
||||
args = make([]interface{}, 0)
|
||||
@ -532,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// read related records.
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||
|
||||
val := reflect.ValueOf(container)
|
||||
@ -640,6 +751,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
slice := ind
|
||||
|
||||
var cnt int64
|
||||
@ -739,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// excute count sql and return count result int64.
|
||||
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
@ -759,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
||||
return
|
||||
}
|
||||
|
||||
// generate sql with replacing operator string placeholders and replaced values.
|
||||
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||
sql := ""
|
||||
params := getFlatParams(fi, args, tz)
|
||||
@ -812,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
|
||||
return sql, params
|
||||
}
|
||||
|
||||
// gernerate sql string with inner function, such as UPPER(text).
|
||||
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
|
||||
// default not use
|
||||
}
|
||||
|
||||
// set values to struct column.
|
||||
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
|
||||
for i, column := range cols {
|
||||
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
|
||||
@ -837,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
|
||||
}
|
||||
}
|
||||
|
||||
// convert value from database result to value following in field type.
|
||||
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
|
||||
if val == nil {
|
||||
return nil, nil
|
||||
@ -989,6 +1107,7 @@ end:
|
||||
|
||||
}
|
||||
|
||||
// set one value to struct column field.
|
||||
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
||||
|
||||
fieldType := fi.fieldType
|
||||
@ -1063,6 +1182,7 @@ setValue:
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// query sql, read values , save to *[]ParamList.
|
||||
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
||||
|
||||
var (
|
||||
@ -1150,6 +1270,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
cnt int64
|
||||
columns []string
|
||||
@ -1228,6 +1350,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// flag of update joined record.
|
||||
func (d *dbBase) SupportUpdateJoin() bool {
|
||||
return true
|
||||
}
|
||||
@ -1236,30 +1359,37 @@ func (d *dbBase) MaxLimit() uint64 {
|
||||
return 18446744073709551615
|
||||
}
|
||||
|
||||
// return quote.
|
||||
func (d *dbBase) TableQuote() string {
|
||||
return "`"
|
||||
}
|
||||
|
||||
// replace value placeholer in parametered sql string.
|
||||
func (d *dbBase) ReplaceMarks(query *string) {
|
||||
// default use `?` as mark, do nothing
|
||||
}
|
||||
|
||||
// flag of RETURNING sql.
|
||||
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// convert time from db.
|
||||
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
||||
*t = t.In(tz)
|
||||
}
|
||||
|
||||
// convert time to db.
|
||||
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
||||
*t = t.In(tz)
|
||||
}
|
||||
|
||||
// get database types.
|
||||
func (d *dbBase) DbTypes() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// gt all tables.
|
||||
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
tables := make(map[string]bool)
|
||||
query := d.ins.ShowTablesQuery()
|
||||
@ -1268,6 +1398,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
return tables, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var table string
|
||||
err := rows.Scan(&table)
|
||||
@ -1282,6 +1414,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// get all cloumns in table.
|
||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
columns := make(map[string][3]string)
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
@ -1290,6 +1423,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
||||
return columns, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
@ -1306,18 +1441,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) OperatorSql(operator string) string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) ShowTablesQuery() string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) ShowColumnsQuery(table string) string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
@ -9,27 +9,32 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// database driver constant int.
|
||||
type DriverType int
|
||||
|
||||
const (
|
||||
_ DriverType = iota
|
||||
DR_MySQL
|
||||
DR_Sqlite
|
||||
DR_Oracle
|
||||
DR_Postgres
|
||||
_ DriverType = iota // int enum type
|
||||
DR_MySQL // mysql
|
||||
DR_Sqlite // sqlite
|
||||
DR_Oracle // oracle
|
||||
DR_Postgres // pgsql
|
||||
)
|
||||
|
||||
// database driver string.
|
||||
type driver string
|
||||
|
||||
// get type constant int of current driver..
|
||||
func (d driver) Type() DriverType {
|
||||
a, _ := dataBaseCache.get(string(d))
|
||||
return a.Driver
|
||||
}
|
||||
|
||||
// get name of current driver
|
||||
func (d driver) Name() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// check driver iis implemented Driver interface or not.
|
||||
var _ Driver = new(driver)
|
||||
|
||||
var (
|
||||
@ -47,11 +52,13 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// database alias cacher.
|
||||
type _dbCache struct {
|
||||
mux sync.RWMutex
|
||||
cache map[string]*alias
|
||||
}
|
||||
|
||||
// add database alias with original name.
|
||||
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// get database alias if cached.
|
||||
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||
ac.mux.RLock()
|
||||
defer ac.mux.RUnlock()
|
||||
@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// get default alias.
|
||||
func (ac *_dbCache) getDefault() (al *alias) {
|
||||
al, _ = ac.get("default")
|
||||
return
|
||||
@ -123,21 +132,18 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
|
||||
switch al.Driver {
|
||||
case DR_MySQL:
|
||||
row := al.DB.QueryRow("SELECT @@session.time_zone")
|
||||
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
||||
var tz string
|
||||
row.Scan(&tz)
|
||||
if tz == "SYSTEM" {
|
||||
tz = ""
|
||||
row = al.DB.QueryRow("SELECT @@system_time_zone")
|
||||
row.Scan(&tz)
|
||||
t, err := time.Parse("MST", tz)
|
||||
if err == nil {
|
||||
al.TZ = t.Location()
|
||||
if len(tz) >= 8 {
|
||||
if tz[0] != '-' {
|
||||
tz = "+" + tz
|
||||
}
|
||||
} else {
|
||||
t, err := time.Parse("-07:00", tz)
|
||||
t, err := time.Parse("-07:00:00", tz)
|
||||
if err == nil {
|
||||
al.TZ = t.Location()
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,6 +169,8 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err == nil {
|
||||
al.TZ = loc
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mysql operators.
|
||||
var mysqlOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ?",
|
||||
@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
|
||||
"iendswith": "LIKE ?",
|
||||
}
|
||||
|
||||
// mysql column field types.
|
||||
var mysqlTypes = map[string]string{
|
||||
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
}
|
||||
|
||||
// mysql dbBaser implementation.
|
||||
type dbBaseMysql struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseMysql)
|
||||
|
||||
// get mysql operator.
|
||||
func (d *dbBaseMysql) OperatorSql(operator string) string {
|
||||
return mysqlOperators[operator]
|
||||
}
|
||||
|
||||
// get mysql table field types.
|
||||
func (d *dbBaseMysql) DbTypes() map[string]string {
|
||||
return mysqlTypes
|
||||
}
|
||||
|
||||
// show table sql for mysql.
|
||||
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||
}
|
||||
|
||||
// show columns sql of table for mysql.
|
||||
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// execute sql to check index exist.
|
||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||
@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// create new mysql dbBaser.
|
||||
func newdbBaseMysql() dbBaser {
|
||||
b := new(dbBaseMysql)
|
||||
b.ins = b
|
||||
|
@ -1,11 +1,13 @@
|
||||
package orm
|
||||
|
||||
// oracle dbBaser
|
||||
type dbBaseOracle struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseOracle)
|
||||
|
||||
// create oracle dbBaser.
|
||||
func newdbBaseOracle() dbBaser {
|
||||
b := new(dbBaseOracle)
|
||||
b.ins = b
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// postgresql operators.
|
||||
var postgresOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "= UPPER(?)",
|
||||
@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
|
||||
"iendswith": "LIKE UPPER(?)",
|
||||
}
|
||||
|
||||
// postgresql column field types.
|
||||
var postgresTypes = map[string]string{
|
||||
"auto": "serial NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
}
|
||||
|
||||
// postgresql dbBaser.
|
||||
type dbBasePostgres struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBasePostgres)
|
||||
|
||||
// get postgresql operator.
|
||||
func (d *dbBasePostgres) OperatorSql(operator string) string {
|
||||
return postgresOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql string, such as contains(text).
|
||||
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
switch operator {
|
||||
case "contains", "startswith", "endswith":
|
||||
@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
|
||||
}
|
||||
}
|
||||
|
||||
// postgresql unsupports updating joined record.
|
||||
func (d *dbBasePostgres) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// postgresql quote is ".
|
||||
func (d *dbBasePostgres) TableQuote() string {
|
||||
return `"`
|
||||
}
|
||||
|
||||
// postgresql value placeholder is $n.
|
||||
// replace default ? to $n.
|
||||
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||
q := *query
|
||||
num := 0
|
||||
@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||
*query = string(data)
|
||||
}
|
||||
|
||||
// make returning sql support for postgresql.
|
||||
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
|
||||
if mi.fields.pk.auto {
|
||||
if query != nil {
|
||||
@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
|
||||
return
|
||||
}
|
||||
|
||||
// show table sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||
}
|
||||
|
||||
// show table columns sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// get column types of postgresql.
|
||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||
return postgresTypes
|
||||
}
|
||||
|
||||
// check index exist in postgresql.
|
||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||
row := db.QueryRow(query)
|
||||
@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// create new postgresql dbBaser.
|
||||
func newdbBasePostgres() dbBaser {
|
||||
b := new(dbBasePostgres)
|
||||
b.ins = b
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// sqlite operators.
|
||||
var sqliteOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ? ESCAPE '\\'",
|
||||
@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
|
||||
"iendswith": "LIKE ? ESCAPE '\\'",
|
||||
}
|
||||
|
||||
// sqlite column types.
|
||||
var sqliteTypes = map[string]string{
|
||||
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
|
||||
"float64-decimal": "decimal",
|
||||
}
|
||||
|
||||
// sqlite dbBaser.
|
||||
type dbBaseSqlite struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseSqlite)
|
||||
|
||||
// get sqlite operator.
|
||||
func (d *dbBaseSqlite) OperatorSql(operator string) string {
|
||||
return sqliteOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql for sqlite.
|
||||
// only support DATE(text).
|
||||
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
if fi.fieldType == TypeDateField {
|
||||
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
|
||||
}
|
||||
}
|
||||
|
||||
// unable updating joined record in sqlite.
|
||||
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// max int in sqlite.
|
||||
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
||||
return 9223372036854775807
|
||||
}
|
||||
|
||||
// get column types in sqlite.
|
||||
func (d *dbBaseSqlite) DbTypes() map[string]string {
|
||||
return sqliteTypes
|
||||
}
|
||||
|
||||
// get show tables sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
}
|
||||
|
||||
// get columns in sqlite.
|
||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
rows, err := db.Query(query)
|
||||
@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// get show columns sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||
}
|
||||
|
||||
// check index exist in sqlite.
|
||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||
rows, err := db.Query(query)
|
||||
@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// create new sqlite dbBaser.
|
||||
func newdbBaseSqlite() dbBaser {
|
||||
b := new(dbBaseSqlite)
|
||||
b.ins = b
|
||||
|
112
orm/db_tables.go
112
orm/db_tables.go
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// table info struct.
|
||||
type dbTable struct {
|
||||
id int
|
||||
index string
|
||||
@ -18,13 +19,17 @@ type dbTable struct {
|
||||
jtl *dbTable
|
||||
}
|
||||
|
||||
// tables collection struct, contains some tables.
|
||||
type dbTables struct {
|
||||
tablesM map[string]*dbTable
|
||||
tables []*dbTable
|
||||
mi *modelInfo
|
||||
base dbBaser
|
||||
skipEnd bool
|
||||
}
|
||||
|
||||
// set table info to collection.
|
||||
// if not exist, create new.
|
||||
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if j, ok := t.tablesM[name]; ok {
|
||||
@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
||||
return t.tablesM[name]
|
||||
}
|
||||
|
||||
// add table info to collection.
|
||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if _, ok := t.tablesM[name]; ok == false {
|
||||
@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
||||
return t.tablesM[name], false
|
||||
}
|
||||
|
||||
// get table info in collection.
|
||||
func (t *dbTables) get(name string) (*dbTable, bool) {
|
||||
j, ok := t.tablesM[name]
|
||||
return j, ok
|
||||
}
|
||||
|
||||
// get related fields info in recursive depth loop.
|
||||
// loop once, depth decreases one.
|
||||
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
||||
if depth < 0 || fi.fieldType == RelManyToMany {
|
||||
return related
|
||||
@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
|
||||
return related
|
||||
}
|
||||
|
||||
// parse related fields.
|
||||
func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
|
||||
relsNum := len(rels)
|
||||
@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
names = append(names, fi.name)
|
||||
mmi = fi.relModelInfo
|
||||
|
||||
if fi.null {
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
}
|
||||
}
|
||||
|
||||
// generate join string.
|
||||
func (t *dbTables) getJoinSql() (join string) {
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
|
||||
return
|
||||
}
|
||||
|
||||
// parse orm model struct field tag expression.
|
||||
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
||||
var (
|
||||
jtl *dbTable
|
||||
fi *fieldInfo
|
||||
fiN *fieldInfo
|
||||
mmi = mi
|
||||
)
|
||||
|
||||
@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
|
||||
inner := true
|
||||
|
||||
loopFor:
|
||||
for i, ex := range exprs {
|
||||
|
||||
fi, ok := mmi.fields.GetByAny(ex)
|
||||
var ok, okN bool
|
||||
|
||||
if fiN != nil {
|
||||
fi = fiN
|
||||
ok = true
|
||||
fiN = nil
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
fi, ok = mmi.fields.GetByAny(ex)
|
||||
}
|
||||
|
||||
_ = okN
|
||||
|
||||
if ok {
|
||||
|
||||
@ -216,44 +243,61 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
mmi = fi.reverseFieldInfo.mi
|
||||
}
|
||||
|
||||
if i < num {
|
||||
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
||||
}
|
||||
|
||||
if isRel && (fi.mi.isThrough == false || num != i) {
|
||||
if fi.null {
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
jt, _ := t.add(names, mmi, fi, inner)
|
||||
jt.jtl = jtl
|
||||
jtl = jt
|
||||
}
|
||||
|
||||
if num == i {
|
||||
if i == 0 || jtl == nil {
|
||||
index = "T0"
|
||||
} else {
|
||||
index = jtl.index
|
||||
}
|
||||
|
||||
info = fi
|
||||
|
||||
if jtl == nil {
|
||||
name = fi.name
|
||||
} else {
|
||||
name = jtl.name + ExprSep + fi.name
|
||||
}
|
||||
|
||||
switch {
|
||||
case fi.rel:
|
||||
|
||||
case fi.reverse:
|
||||
switch fi.reverseFieldInfo.fieldType {
|
||||
case RelOneToOne, RelForeignKey:
|
||||
index = jtl.index
|
||||
info = fi.reverseFieldInfo.mi.fields.pk
|
||||
name = info.name
|
||||
if t.skipEnd && okN || !t.skipEnd {
|
||||
if t.skipEnd && okN && fiN.pk {
|
||||
goto loopEnd
|
||||
}
|
||||
|
||||
jt, _ := t.add(names, mmi, fi, inner)
|
||||
jt.jtl = jtl
|
||||
jtl = jt
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if num != i {
|
||||
continue
|
||||
}
|
||||
|
||||
loopEnd:
|
||||
|
||||
if i == 0 || jtl == nil {
|
||||
index = "T0"
|
||||
} else {
|
||||
index = jtl.index
|
||||
}
|
||||
|
||||
info = fi
|
||||
|
||||
if jtl == nil {
|
||||
name = fi.name
|
||||
} else {
|
||||
name = jtl.name + ExprSep + fi.name
|
||||
}
|
||||
|
||||
switch {
|
||||
case fi.rel:
|
||||
|
||||
case fi.reverse:
|
||||
switch fi.reverseFieldInfo.fieldType {
|
||||
case RelOneToOne, RelForeignKey:
|
||||
index = jtl.index
|
||||
info = fi.reverseFieldInfo.mi.fields.pk
|
||||
name = info.name
|
||||
}
|
||||
}
|
||||
|
||||
break loopFor
|
||||
|
||||
} else {
|
||||
index = ""
|
||||
name = ""
|
||||
@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
return
|
||||
}
|
||||
|
||||
// generate condition sql.
|
||||
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
||||
if cond == nil || cond.IsEmpty() {
|
||||
return
|
||||
@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
|
||||
return
|
||||
}
|
||||
|
||||
// generate order sql.
|
||||
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||
if len(orders) == 0 {
|
||||
return
|
||||
@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||
return
|
||||
}
|
||||
|
||||
// generate limit sql.
|
||||
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
|
||||
if limit == 0 {
|
||||
limit = int64(DefaultRowsLimit)
|
||||
@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
|
||||
return
|
||||
}
|
||||
|
||||
// crete new tables collection.
|
||||
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
||||
tables := &dbTables{}
|
||||
tables.tablesM = make(map[string]*dbTable)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// get table alias.
|
||||
func getDbAlias(name string) *alias {
|
||||
if al, ok := dataBaseCache.get(name); ok {
|
||||
return al
|
||||
@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get pk column info.
|
||||
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
||||
fi := mi.fields.pk
|
||||
|
||||
@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
|
||||
return
|
||||
}
|
||||
|
||||
// get fields description as flatted string.
|
||||
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
|
||||
|
||||
outFor:
|
||||
|
@ -41,6 +41,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// model info collection
|
||||
type _modelCache struct {
|
||||
sync.RWMutex
|
||||
orders []string
|
||||
@ -49,6 +50,7 @@ type _modelCache struct {
|
||||
done bool
|
||||
}
|
||||
|
||||
// get all model info
|
||||
func (mc *_modelCache) all() map[string]*modelInfo {
|
||||
m := make(map[string]*modelInfo, len(mc.cache))
|
||||
for k, v := range mc.cache {
|
||||
@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
|
||||
return m
|
||||
}
|
||||
|
||||
// get orderd model info
|
||||
func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||
m := make([]*modelInfo, 0, len(mc.orders))
|
||||
for _, table := range mc.orders {
|
||||
@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||
return m
|
||||
}
|
||||
|
||||
// get model info by table name
|
||||
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cache[table]
|
||||
return
|
||||
}
|
||||
|
||||
// get model info by field name
|
||||
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cacheByFN[name]
|
||||
return
|
||||
}
|
||||
|
||||
// set model info to collection
|
||||
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||
mii := mc.cache[table]
|
||||
mc.cache[table] = mi
|
||||
@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||
return mii
|
||||
}
|
||||
|
||||
// clean all model info.
|
||||
func (mc *_modelCache) clean() {
|
||||
mc.orders = make([]string, 0)
|
||||
mc.cache = make(map[string]*modelInfo)
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// register models.
|
||||
// prefix means table name prefix.
|
||||
func registerModel(model interface{}, prefix string) {
|
||||
val := reflect.ValueOf(model)
|
||||
ind := reflect.Indirect(val)
|
||||
@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
|
||||
modelCache.set(table, info)
|
||||
}
|
||||
|
||||
// boostrap models
|
||||
func bootStrap() {
|
||||
if modelCache.done {
|
||||
return
|
||||
@ -281,6 +284,7 @@ end:
|
||||
}
|
||||
}
|
||||
|
||||
// register models
|
||||
func RegisterModel(models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
||||
@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// bootrap models.
|
||||
// make all model parsed and can not add more models
|
||||
func BootStrap() {
|
||||
if modelCache.done {
|
||||
return
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
var errSkipField = errors.New("skip field")
|
||||
|
||||
// field info collection
|
||||
type fields struct {
|
||||
pk *fieldInfo
|
||||
columns map[string]*fieldInfo
|
||||
@ -23,6 +24,7 @@ type fields struct {
|
||||
dbcols []string
|
||||
}
|
||||
|
||||
// add field info
|
||||
func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
|
||||
f.columns[fi.column] = fi
|
||||
@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// get field info by name
|
||||
func (f *fields) GetByName(name string) *fieldInfo {
|
||||
return f.fields[name]
|
||||
}
|
||||
|
||||
// get field info by column name
|
||||
func (f *fields) GetByColumn(column string) *fieldInfo {
|
||||
return f.columns[column]
|
||||
}
|
||||
|
||||
// get field info by string, name is prior
|
||||
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||
if fi, ok := f.fields[name]; ok {
|
||||
return fi, ok
|
||||
@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// create new field info collection
|
||||
func newFields() *fields {
|
||||
f := new(fields)
|
||||
f.fields = make(map[string]*fieldInfo)
|
||||
@ -79,6 +85,7 @@ func newFields() *fields {
|
||||
return f
|
||||
}
|
||||
|
||||
// single field info
|
||||
type fieldInfo struct {
|
||||
mi *modelInfo
|
||||
fieldIndex int
|
||||
@ -115,6 +122,7 @@ type fieldInfo struct {
|
||||
onDelete string
|
||||
}
|
||||
|
||||
// new field info
|
||||
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
|
||||
var (
|
||||
tag string
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// single model info
|
||||
type modelInfo struct {
|
||||
pkg string
|
||||
name string
|
||||
@ -20,6 +21,7 @@ type modelInfo struct {
|
||||
isThrough bool
|
||||
}
|
||||
|
||||
// new model info
|
||||
func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||
var (
|
||||
err error
|
||||
@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// combine related model info to new model info.
|
||||
// prepare for relation models query.
|
||||
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
||||
info = new(modelInfo)
|
||||
info.fields = newFields()
|
||||
|
@ -7,10 +7,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// get reflect.Type name with package path.
|
||||
func getFullName(typ reflect.Type) string {
|
||||
return typ.PkgPath() + "." + typ.Name()
|
||||
}
|
||||
|
||||
// get table name. method, or field name. auto snaked.
|
||||
func getTableName(val reflect.Value) string {
|
||||
ind := reflect.Indirect(val)
|
||||
fun := val.MethodByName("TableName")
|
||||
@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
|
||||
return snakeString(ind.Type().Name())
|
||||
}
|
||||
|
||||
// get table engine, mysiam or innodb.
|
||||
func getTableEngine(val reflect.Value) string {
|
||||
fun := val.MethodByName("TableEngine")
|
||||
if fun.IsValid() {
|
||||
@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// get table index from method.
|
||||
func getTableIndex(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableIndex")
|
||||
if fun.IsValid() {
|
||||
@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get table unique from method
|
||||
func getTableUnique(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableUnique")
|
||||
if fun.IsValid() {
|
||||
@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get snaked column name
|
||||
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
||||
col = strings.ToLower(col)
|
||||
column := col
|
||||
@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
|
||||
return column
|
||||
}
|
||||
|
||||
// return field type as type constant from reflect.Value
|
||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||
elm := reflect.Indirect(val)
|
||||
switch elm.Kind() {
|
||||
@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// parse struct tag string
|
||||
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
|
||||
attr := make(map[string]bool)
|
||||
tag := make(map[string]string)
|
||||
|
118
orm/orm.go
118
orm/orm.go
@ -25,6 +25,7 @@ var (
|
||||
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
|
||||
ErrNoRows = errors.New("<QuerySeter> no row found")
|
||||
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
|
||||
ErrArgs = errors.New("<Ormer> args error may be empty")
|
||||
ErrNotImplement = errors.New("have not implement")
|
||||
)
|
||||
|
||||
@ -39,11 +40,12 @@ type orm struct {
|
||||
|
||||
var _ Ormer = new(orm)
|
||||
|
||||
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
||||
// get model info and model reflect value
|
||||
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
|
||||
val := reflect.ValueOf(md)
|
||||
ind = reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
if val.Kind() != reflect.Ptr {
|
||||
if needPtr && val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
||||
}
|
||||
name := getFullName(typ)
|
||||
@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
||||
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
||||
}
|
||||
|
||||
// get field info from model info by given field name
|
||||
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||
fi, ok := mi.fields.GetByAny(name)
|
||||
if !ok {
|
||||
@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||
return fi
|
||||
}
|
||||
|
||||
// read data to model
|
||||
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -70,26 +74,69 @@ func (o *orm) Read(md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// insert model data to database
|
||||
func (o *orm) Insert(md interface{}) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
if id > 0 {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||
} else {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// set auto pk field
|
||||
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||
} else {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// insert some models to database
|
||||
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||
var cnt int64
|
||||
|
||||
sind := reflect.Indirect(reflect.ValueOf(mds))
|
||||
|
||||
switch sind.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
if sind.Len() == 0 {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
default:
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
if bulk <= 1 {
|
||||
for i := 0; i < sind.Len(); i++ {
|
||||
ind := sind.Index(i)
|
||||
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
cnt += 1
|
||||
}
|
||||
} else {
|
||||
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
||||
}
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// update model to database.
|
||||
// cols set the columns those want to update.
|
||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return num, err
|
||||
@ -97,26 +144,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// delete model in database
|
||||
func (o *orm) Delete(md interface{}) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
if num > 0 {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
|
||||
} else {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
|
||||
}
|
||||
}
|
||||
o.setPk(mi, ind, 0)
|
||||
}
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// create a models to models queryer
|
||||
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
switch {
|
||||
@ -129,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
return newQueryM2M(md, o, mi, fi, ind)
|
||||
}
|
||||
|
||||
// load related models to md model.
|
||||
// args are limit, offset int and order string.
|
||||
//
|
||||
// example:
|
||||
// orm.LoadRelated(post,"Tags")
|
||||
// for _,tag := range post.Tags{...}
|
||||
//
|
||||
// make sure the relation is defined in model struct tags.
|
||||
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
|
||||
_, fi, ind, qseter := o.queryRelated(md, name)
|
||||
|
||||
@ -190,14 +241,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
|
||||
return nums, err
|
||||
}
|
||||
|
||||
// return a QuerySeter for related models to md model.
|
||||
// it can do all, update, delete in QuerySeter.
|
||||
// example:
|
||||
// qs := orm.QueryRelated(post,"Tag")
|
||||
// qs.All(&[]*Tag{})
|
||||
//
|
||||
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
|
||||
// is this api needed ?
|
||||
_, _, _, qs := o.queryRelated(md, name)
|
||||
return qs
|
||||
}
|
||||
|
||||
// get QuerySeter for related models to md model
|
||||
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
_, _, exist := getExistPk(mi, ind)
|
||||
@ -227,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
|
||||
return mi, fi, ind, qs
|
||||
}
|
||||
|
||||
// get reverse relation QuerySeter
|
||||
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelReverseOne, RelReverseMany:
|
||||
@ -247,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
|
||||
return q
|
||||
}
|
||||
|
||||
// get relation QuerySeter
|
||||
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelManyToMany:
|
||||
@ -266,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
return q
|
||||
}
|
||||
|
||||
// return a QuerySeter for table operations.
|
||||
// table name can be string or struct.
|
||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
name := ""
|
||||
if table, ok := ptrStructOrTableName.(string); ok {
|
||||
@ -285,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
return
|
||||
}
|
||||
|
||||
// switch to another registered database driver by given name.
|
||||
func (o *orm) Using(name string) error {
|
||||
if o.isTx {
|
||||
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
|
||||
@ -302,6 +366,7 @@ func (o *orm) Using(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
func (o *orm) Begin() error {
|
||||
if o.isTx {
|
||||
return ErrTxHasBegan
|
||||
@ -320,6 +385,7 @@ func (o *orm) Begin() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// commit transaction
|
||||
func (o *orm) Commit() error {
|
||||
if o.isTx == false {
|
||||
return ErrTxDone
|
||||
@ -334,6 +400,7 @@ func (o *orm) Commit() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// rollback transaction
|
||||
func (o *orm) Rollback() error {
|
||||
if o.isTx == false {
|
||||
return ErrTxDone
|
||||
@ -348,14 +415,17 @@ func (o *orm) Rollback() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// return a raw query seter for raw sql string.
|
||||
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
|
||||
return newRawSet(o, query, args)
|
||||
}
|
||||
|
||||
// return current using database Driver
|
||||
func (o *orm) Driver() Driver {
|
||||
return driver(o.alias.Name)
|
||||
}
|
||||
|
||||
// create new orm
|
||||
func NewOrm() Ormer {
|
||||
BootStrap() // execute only once
|
||||
|
||||
|
@ -18,15 +18,19 @@ type condValue struct {
|
||||
isCond bool
|
||||
}
|
||||
|
||||
// condition struct.
|
||||
// work for WHERE conditions.
|
||||
type Condition struct {
|
||||
params []condValue
|
||||
}
|
||||
|
||||
// return new condition struct
|
||||
func NewCondition() *Condition {
|
||||
c := &Condition{}
|
||||
return c
|
||||
}
|
||||
|
||||
// add expression to condition
|
||||
func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.And> args cannot empty"))
|
||||
@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// add NOT expression to condition
|
||||
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
|
||||
@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// combine a condition to current condition
|
||||
func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||
return c
|
||||
}
|
||||
|
||||
// add OR expression to condition
|
||||
func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
|
||||
@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// add OR NOT expression to condition
|
||||
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
|
||||
@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// combine a OR condition to current condition
|
||||
func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||
return c
|
||||
}
|
||||
|
||||
// check the condition arguments are empty or not.
|
||||
func (c *Condition) IsEmpty() bool {
|
||||
return len(c.params) == 0
|
||||
}
|
||||
|
||||
// clone a condition
|
||||
func (c Condition) clone() *Condition {
|
||||
return &c
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ type Log struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
// set io.Writer to create a Logger.
|
||||
func NewLog(out io.Writer) *Log {
|
||||
d := new(Log)
|
||||
d.Logger = log.New(out, "[ORM]", 1e9)
|
||||
@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
|
||||
DebugLog.Println(con)
|
||||
}
|
||||
|
||||
// statement query logger struct.
|
||||
// if dev mode, use stmtQueryLog, or use stmtQuerier.
|
||||
type stmtQueryLog struct {
|
||||
alias *alias
|
||||
query string
|
||||
@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
|
||||
return d
|
||||
}
|
||||
|
||||
// database query logger struct.
|
||||
// if dev mode, use dbQueryLog, or use dbQuerier.
|
||||
type dbQueryLog struct {
|
||||
alias *alias
|
||||
db dbQuerier
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// an insert queryer struct
|
||||
type insertSet struct {
|
||||
mi *modelInfo
|
||||
orm *orm
|
||||
@ -14,6 +15,7 @@ type insertSet struct {
|
||||
|
||||
var _ Inserter = new(insertSet)
|
||||
|
||||
// insert model ignore it's registered or not.
|
||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||
if o.closed {
|
||||
return 0, ErrStmtClosed
|
||||
@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// close insert queryer statement
|
||||
func (o *insertSet) Close() error {
|
||||
if o.closed {
|
||||
return ErrStmtClosed
|
||||
@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
|
||||
return o.stmt.Close()
|
||||
}
|
||||
|
||||
// create new insert queryer.
|
||||
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
|
||||
bi := new(insertSet)
|
||||
bi.orm = orm
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// model to model struct
|
||||
type queryM2M struct {
|
||||
md interface{}
|
||||
mi *modelInfo
|
||||
@ -12,6 +13,13 @@ type queryM2M struct {
|
||||
ind reflect.Value
|
||||
}
|
||||
|
||||
// add models to origin models when creating queryM2M.
|
||||
// example:
|
||||
// m2m := orm.QueryM2M(post,"Tag")
|
||||
// m2m.Add(&Tag1{},&Tag2{})
|
||||
// for _,tag := range post.Tags{}
|
||||
//
|
||||
// make sure the relation is defined in post model struct tag.
|
||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
mi := fi.relThroughModelInfo
|
||||
@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
|
||||
names := []string{mfi.column, rfi.column}
|
||||
|
||||
var nums int64
|
||||
values := make([]interface{}, 0, len(models)*2)
|
||||
|
||||
for _, md := range models {
|
||||
|
||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||
@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
values := []interface{}{v1, v2}
|
||||
_, err := dbase.InsertValue(orm.db, mi, names, values)
|
||||
if err != nil {
|
||||
return nums, err
|
||||
}
|
||||
values = append(values, v1, v2)
|
||||
|
||||
nums += 1
|
||||
}
|
||||
|
||||
return nums, nil
|
||||
return dbase.InsertValue(orm.db, mi, true, names, values)
|
||||
}
|
||||
|
||||
// remove models following the origin model relationship
|
||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||
@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||
return nums, nil
|
||||
}
|
||||
|
||||
// check model is existed in relationship of origin model
|
||||
func (o *queryM2M) Exist(md interface{}) bool {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
||||
}
|
||||
|
||||
// clean all models in related of origin model
|
||||
func (o *queryM2M) Clear() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
||||
}
|
||||
|
||||
// count all related models of origin model
|
||||
func (o *queryM2M) Count() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
||||
@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
|
||||
|
||||
var _ QueryM2Mer = new(queryM2M)
|
||||
|
||||
// create new M2M queryer.
|
||||
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
|
||||
qm2m := new(queryM2M)
|
||||
qm2m.md = md
|
||||
|
@ -18,6 +18,10 @@ const (
|
||||
Col_Except
|
||||
)
|
||||
|
||||
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
|
||||
// Params{
|
||||
// "Nums": ColValue(Col_Add, 10),
|
||||
// }
|
||||
func ColValue(opt operator, value interface{}) interface{} {
|
||||
switch opt {
|
||||
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
|
||||
@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
|
||||
return val
|
||||
}
|
||||
|
||||
// real query struct
|
||||
type querySet struct {
|
||||
mi *modelInfo
|
||||
cond *Condition
|
||||
@ -47,6 +52,7 @@ type querySet struct {
|
||||
|
||||
var _ QuerySeter = new(querySet)
|
||||
|
||||
// add condition expression to QuerySeter.
|
||||
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// add NOT condition to querySeter.
|
||||
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// set offset number
|
||||
func (o *querySet) setOffset(num interface{}) {
|
||||
o.offset = ToInt64(num)
|
||||
}
|
||||
|
||||
// add LIMIT value.
|
||||
// args[0] means offset, e.g. LIMIT num,offset.
|
||||
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||
o.limit = ToInt64(limit)
|
||||
if len(args) > 0 {
|
||||
@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// add OFFSET value
|
||||
func (o querySet) Offset(offset interface{}) QuerySeter {
|
||||
o.setOffset(offset)
|
||||
return &o
|
||||
}
|
||||
|
||||
// add ORDER expression.
|
||||
// "column" means ASC, "-column" means DESC.
|
||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
||||
o.orders = exprs
|
||||
return &o
|
||||
}
|
||||
|
||||
// set relation model to query together.
|
||||
// it will query relation models and assign to parent model.
|
||||
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||
var related []string
|
||||
if len(params) == 0 {
|
||||
@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// set condition to QuerySeter.
|
||||
func (o querySet) SetCond(cond *Condition) QuerySeter {
|
||||
o.cond = cond
|
||||
return &o
|
||||
}
|
||||
|
||||
// return QuerySeter execution result number
|
||||
func (o *querySet) Count() (int64, error) {
|
||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// check result empty or not after QuerySeter executed
|
||||
func (o *querySet) Exist() bool {
|
||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// execute update with parameters
|
||||
func (o *querySet) Update(values Params) (int64, error) {
|
||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// execute delete
|
||||
func (o *querySet) Delete() (int64, error) {
|
||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// return a insert queryer.
|
||||
// it can be used in times.
|
||||
// example:
|
||||
// i,err := sq.PrepareInsert()
|
||||
// i.Add(&user1{},&user2{})
|
||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||
return newInsertSet(o.orm, o.mi)
|
||||
}
|
||||
|
||||
// query all data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
}
|
||||
|
||||
// query one row data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
if err != nil {
|
||||
@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// query all data and map to []map[string]interface.
|
||||
// expres means condition expression.
|
||||
// it converts data to []map[column]value.
|
||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to [][]interface
|
||||
// it converts data to [][column_index]value
|
||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to []interface.
|
||||
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// create new QuerySeter.
|
||||
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
|
||||
o := new(querySet)
|
||||
o.mi = mi
|
||||
|
326
orm/orm_raw.go
326
orm/orm_raw.go
@ -4,10 +4,10 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// raw sql string prepared statement
|
||||
type rawPrepare struct {
|
||||
rs *rawSet
|
||||
stmt stmtQuerier
|
||||
@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// raw query seter
|
||||
type rawSet struct {
|
||||
query string
|
||||
args []interface{}
|
||||
@ -53,11 +54,13 @@ type rawSet struct {
|
||||
|
||||
var _ RawSeter = new(rawSet)
|
||||
|
||||
// set args for every query
|
||||
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
||||
o.args = args
|
||||
return &o
|
||||
}
|
||||
|
||||
// execute raw sql and return sql.Result
|
||||
func (o *rawSet) Exec() (sql.Result, error) {
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
|
||||
return o.orm.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
// set field value to row container
|
||||
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||
switch ind.Kind() {
|
||||
case reflect.Bool:
|
||||
@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
|
||||
sIdxes := *sIdxesPtr
|
||||
refs := *refsPtr
|
||||
|
||||
if typ.Kind() == reflect.Struct {
|
||||
if typ.String() == "time.Time" {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
sIdxes = append(sIdxes, []int{0})
|
||||
} else {
|
||||
idxs := []int{}
|
||||
outFor:
|
||||
for idx := 0; idx < typ.NumField(); idx++ {
|
||||
ctyp := typ.Field(idx)
|
||||
|
||||
tag := ctyp.Tag.Get(defaultStructTagName)
|
||||
for _, v := range strings.Split(tag, defaultStructTagDelim) {
|
||||
if v == "-" {
|
||||
continue outFor
|
||||
}
|
||||
}
|
||||
|
||||
tp := ctyp.Type
|
||||
if tp.Kind() == reflect.Ptr {
|
||||
tp = tp.Elem()
|
||||
}
|
||||
|
||||
if tp.String() == "time.Time" {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
|
||||
} else if tp.Kind() != reflect.Struct {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
|
||||
} else {
|
||||
// skip other type
|
||||
continue
|
||||
}
|
||||
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
sIdxes = append(sIdxes, idxs)
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
sIdxes = append(sIdxes, []int{0})
|
||||
}
|
||||
|
||||
*sIdxesPtr = sIdxes
|
||||
*refsPtr = refs
|
||||
}
|
||||
|
||||
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||
// set field value in loop for slice container
|
||||
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||
nInds := *nIndsPtr
|
||||
|
||||
cur := 0
|
||||
for i, idxs := range sIdxes {
|
||||
for i := 0; i < len(sInds); i++ {
|
||||
sInd := sInds[i]
|
||||
eTyp := eTyps[i]
|
||||
|
||||
@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
||||
o.setFieldValue(ind, value)
|
||||
}
|
||||
cur++
|
||||
} else {
|
||||
hasValue := false
|
||||
for _, idx := range idxs {
|
||||
tind := ind.Field(idx)
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if value != nil {
|
||||
hasValue = true
|
||||
}
|
||||
if tind.Kind() == reflect.Ptr {
|
||||
if value == nil {
|
||||
tindV := reflect.New(tind.Type()).Elem()
|
||||
tind.Set(tindV)
|
||||
} else {
|
||||
tindV := reflect.New(tind.Type().Elem())
|
||||
o.setFieldValue(tindV.Elem(), value)
|
||||
tind.Set(tindV)
|
||||
}
|
||||
} else {
|
||||
o.setFieldValue(tind, value)
|
||||
}
|
||||
cur++
|
||||
}
|
||||
if hasValue == false && isPtr {
|
||||
val = reflect.New(val.Type()).Elem()
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if isPtr && value == nil {
|
||||
@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
||||
}
|
||||
}
|
||||
|
||||
// query data and map to container
|
||||
func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||
if len(containers) == 0 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
|
||||
}
|
||||
|
||||
refs := make([]interface{}, 0, len(containers))
|
||||
sIdxes := make([][]int, 0)
|
||||
sInds := make([]reflect.Value, 0)
|
||||
eTyps := make([]reflect.Type, 0)
|
||||
|
||||
structMode := false
|
||||
var sMi *modelInfo
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
ind := reflect.Indirect(val)
|
||||
@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
sInds = append(sInds, ind)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFN(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
row := o.orm.db.QueryRow(query, args...)
|
||||
|
||||
if err := row.Scan(refs...); err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
} else if err != nil {
|
||||
rows, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ind := sInds[0]
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
if ind.IsNil() || !ind.IsValid() {
|
||||
ind.Set(reflect.New(eTyps[0].Elem()))
|
||||
}
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
f := ind.Field(i)
|
||||
fe := ind.Type().Field(i)
|
||||
|
||||
var attrs map[string]bool
|
||||
var tags map[string]string
|
||||
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||
var col string
|
||||
if col = tags["column"]; len(col) == 0 {
|
||||
col = snakeString(fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
return ErrNoRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// query data rows and map to container
|
||||
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
refs := make([]interface{}, 0)
|
||||
sIdxes := make([][]int, 0)
|
||||
refs := make([]interface{}, 0, len(containers))
|
||||
sInds := make([]reflect.Value, 0)
|
||||
eTyps := make([]reflect.Type, 0)
|
||||
|
||||
structMode := false
|
||||
var sMi *modelInfo
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
sInd := reflect.Indirect(val)
|
||||
@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
sInds = append(sInds, sInd)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFN(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
@ -401,23 +418,100 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
defer rows.Close()
|
||||
|
||||
var cnt int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
sInd := sInds[0]
|
||||
|
||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
|
||||
for rows.Next() {
|
||||
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if cnt == 0 && !sInd.IsNil() {
|
||||
sInd.Set(reflect.New(sInd.Type()).Elem())
|
||||
}
|
||||
|
||||
var ind reflect.Value
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = reflect.New(eTyps[0].Elem())
|
||||
} else {
|
||||
ind = reflect.New(eTyps[0])
|
||||
}
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
f := ind.Field(i)
|
||||
fe := ind.Type().Field(i)
|
||||
|
||||
var attrs map[string]bool
|
||||
var tags map[string]string
|
||||
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||
var col string
|
||||
if col = tags["column"]; len(col) == 0 {
|
||||
col = snakeString(fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = ind.Addr()
|
||||
}
|
||||
|
||||
sInd = reflect.Append(sInd, ind)
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if cnt > 0 {
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
|
||||
if structMode {
|
||||
sInds[0].Set(sInd)
|
||||
} else {
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -455,6 +549,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
@ -527,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// query data to []map[string]interface
|
||||
func (o *rawSet) Values(container *[]Params) (int64, error) {
|
||||
return o.readValues(container)
|
||||
}
|
||||
|
||||
// query data to [][]interface
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
}
|
||||
|
||||
// query data to []interface
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
}
|
||||
|
||||
// return prepared raw statement for used in times.
|
||||
func (o *rawSet) Prepare() (RawPreparer, error) {
|
||||
return newRawPreparer(o)
|
||||
}
|
||||
|
204
orm/orm_test.go
204
orm/orm_test.go
@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type Tmp struct {
|
||||
Skip0 string
|
||||
Id int
|
||||
Char *string
|
||||
Skip1 int `orm:"-"`
|
||||
Date time.Time
|
||||
DateTime time.Time
|
||||
}
|
||||
|
||||
Boolean = false
|
||||
Text = ""
|
||||
Int64 = 0
|
||||
Uint = 0
|
||||
|
||||
tmp := new(Tmp)
|
||||
|
||||
cols = []string{
|
||||
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
|
||||
}
|
||||
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
|
||||
values = []interface{}{
|
||||
tmp, &Boolean, &Text, &Int64, &Uint,
|
||||
}
|
||||
err = dORM.Raw(query, 1).QueryRow(values...)
|
||||
throwFailNow(t, err)
|
||||
|
||||
for _, col := range cols {
|
||||
switch col {
|
||||
case "id":
|
||||
throwFail(t, AssertIs(tmp.Id, data_values[col]))
|
||||
case "char":
|
||||
c := tmp.Char
|
||||
throwFail(t, AssertIs(*c, data_values[col]))
|
||||
case "date":
|
||||
v := tmp.Date.In(DefaultTimeLoc)
|
||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
||||
throwFail(t, AssertIs(v, value, test_Date))
|
||||
case "datetime":
|
||||
v := tmp.DateTime.In(DefaultTimeLoc)
|
||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
||||
throwFail(t, AssertIs(v, value, test_DateTime))
|
||||
case "boolean":
|
||||
throwFail(t, AssertIs(Boolean, data_values[col]))
|
||||
case "text":
|
||||
throwFail(t, AssertIs(Text, data_values[col]))
|
||||
case "int64":
|
||||
throwFail(t, AssertIs(Int64, data_values[col]))
|
||||
case "uint":
|
||||
throwFail(t, AssertIs(Uint, data_values[col]))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
uid int
|
||||
status *int
|
||||
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
|
||||
func TestQueryRows(t *testing.T) {
|
||||
Q := dDbBaser.TableQuote()
|
||||
|
||||
cols := []string{
|
||||
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
|
||||
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
|
||||
}
|
||||
|
||||
var datas []*Data
|
||||
var dids []int
|
||||
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
|
||||
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||
num, err := dORM.Raw(query).QueryRows(&datas)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 1))
|
||||
throwFailNow(t, AssertIs(len(datas), 1))
|
||||
throwFailNow(t, AssertIs(len(dids), 1))
|
||||
throwFailNow(t, AssertIs(dids[0], 1))
|
||||
|
||||
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
|
||||
|
||||
@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
|
||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||
}
|
||||
|
||||
type Tmp struct {
|
||||
Id int
|
||||
Name string
|
||||
Skiped0 string `orm:"-"`
|
||||
Pid *int
|
||||
Skiped1 Data
|
||||
Skiped2 *Data
|
||||
var datas2 []Data
|
||||
|
||||
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&datas2)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 1))
|
||||
throwFailNow(t, AssertIs(len(datas2), 1))
|
||||
|
||||
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
|
||||
|
||||
for name, value := range Data_Values {
|
||||
e := ind.FieldByName(name)
|
||||
vu := e.Interface()
|
||||
switch name {
|
||||
case "Date":
|
||||
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||
case "DateTime":
|
||||
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||
}
|
||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||
}
|
||||
|
||||
var (
|
||||
ids []int
|
||||
userNames []string
|
||||
profileIds1 []int
|
||||
profileIds2 []*int
|
||||
createds []time.Time
|
||||
updateds []time.Time
|
||||
tmps1 []*Tmp
|
||||
tmps2 []Tmp
|
||||
)
|
||||
cols = []string{
|
||||
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
|
||||
}
|
||||
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
|
||||
var ids []int
|
||||
var usernames []string
|
||||
query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 3))
|
||||
|
||||
var users []User
|
||||
dORM.QueryTable("user").OrderBy("Id").All(&users)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
id := ids[i]
|
||||
name := userNames[i]
|
||||
pid1 := profileIds1[i]
|
||||
pid2 := profileIds2[i]
|
||||
created := createds[i]
|
||||
updated := updateds[i]
|
||||
|
||||
user := users[i]
|
||||
throwFailNow(t, AssertIs(id, user.Id))
|
||||
throwFailNow(t, AssertIs(name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
|
||||
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(pid1, 0))
|
||||
throwFailNow(t, AssertIs(pid2, nil))
|
||||
}
|
||||
throwFailNow(t, AssertIs(created, user.Created, test_Date))
|
||||
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
|
||||
|
||||
tmp := tmps1[i]
|
||||
tmp1 := *tmp
|
||||
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
|
||||
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
pid := tmp1.Pid
|
||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(tmp1.Pid, nil))
|
||||
}
|
||||
|
||||
tmp2 := tmps2[i]
|
||||
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
|
||||
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
pid := tmp2.Pid
|
||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(tmp2.Pid, nil))
|
||||
}
|
||||
}
|
||||
|
||||
type Sec struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
|
||||
var tmp []*Sec
|
||||
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&tmp)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
throwFail(t, AssertIs(tmp[0], nil))
|
||||
throwFailNow(t, AssertIs(len(ids), 3))
|
||||
throwFailNow(t, AssertIs(ids[0], 2))
|
||||
throwFailNow(t, AssertIs(usernames[0], "slene"))
|
||||
throwFailNow(t, AssertIs(ids[1], 3))
|
||||
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
|
||||
throwFailNow(t, AssertIs(ids[2], 4))
|
||||
throwFailNow(t, AssertIs(usernames[2], "nobody"))
|
||||
}
|
||||
|
||||
func TestRawValues(t *testing.T) {
|
||||
@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
|
||||
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 6))
|
||||
|
||||
qs = dORM.QueryTable("post")
|
||||
num, err = qs.Filter("Id", 3).Delete()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 4))
|
||||
|
||||
fmt.Println("...")
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Filter("Post__User", 3).Delete()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 3))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
|
17
orm/types.go
17
orm/types.go
@ -6,11 +6,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// database driver
|
||||
type Driver interface {
|
||||
Name() string
|
||||
Type() DriverType
|
||||
}
|
||||
|
||||
// field info
|
||||
type Fielder interface {
|
||||
String() string
|
||||
FieldType() int
|
||||
@ -18,9 +20,11 @@ type Fielder interface {
|
||||
RawValue() interface{}
|
||||
}
|
||||
|
||||
// orm struct
|
||||
type Ormer interface {
|
||||
Read(interface{}, ...string) error
|
||||
Insert(interface{}) (int64, error)
|
||||
InsertMulti(int, interface{}) (int64, error)
|
||||
Update(interface{}, ...string) (int64, error)
|
||||
Delete(interface{}) (int64, error)
|
||||
LoadRelated(interface{}, string, ...interface{}) (int64, error)
|
||||
@ -34,11 +38,13 @@ type Ormer interface {
|
||||
Driver() Driver
|
||||
}
|
||||
|
||||
// insert prepared statement
|
||||
type Inserter interface {
|
||||
Insert(interface{}) (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// query seter
|
||||
type QuerySeter interface {
|
||||
Filter(string, ...interface{}) QuerySeter
|
||||
Exclude(string, ...interface{}) QuerySeter
|
||||
@ -59,6 +65,7 @@ type QuerySeter interface {
|
||||
ValuesFlat(*ParamsList, string) (int64, error)
|
||||
}
|
||||
|
||||
// model to model query struct
|
||||
type QueryM2Mer interface {
|
||||
Add(...interface{}) (int64, error)
|
||||
Remove(...interface{}) (int64, error)
|
||||
@ -67,11 +74,13 @@ type QueryM2Mer interface {
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
// raw query statement
|
||||
type RawPreparer interface {
|
||||
Exec(...interface{}) (sql.Result, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// raw query seter
|
||||
type RawSeter interface {
|
||||
Exec() (sql.Result, error)
|
||||
QueryRow(...interface{}) error
|
||||
@ -83,6 +92,7 @@ type RawSeter interface {
|
||||
Prepare() (RawPreparer, error)
|
||||
}
|
||||
|
||||
// statement querier
|
||||
type stmtQuerier interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
@ -90,6 +100,7 @@ type stmtQuerier interface {
|
||||
QueryRow(args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// db querier
|
||||
type dbQuerier interface {
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
@ -97,19 +108,23 @@ type dbQuerier interface {
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// transaction beginner
|
||||
type txer interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// transaction ending
|
||||
type txEnder interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// base database struct
|
||||
type dbBaser interface {
|
||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
|
||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
|
27
orm/utils.go
27
orm/utils.go
@ -10,6 +10,7 @@ import (
|
||||
|
||||
type StrTo string
|
||||
|
||||
// set string
|
||||
func (f *StrTo) Set(v string) {
|
||||
if v != "" {
|
||||
*f = StrTo(v)
|
||||
@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean string
|
||||
func (f *StrTo) Clear() {
|
||||
*f = StrTo(0x1E)
|
||||
}
|
||||
|
||||
// check string exist
|
||||
func (f StrTo) Exist() bool {
|
||||
return string(f) != string(0x1E)
|
||||
}
|
||||
|
||||
// string to bool
|
||||
func (f StrTo) Bool() (bool, error) {
|
||||
return strconv.ParseBool(f.String())
|
||||
}
|
||||
|
||||
// string to float32
|
||||
func (f StrTo) Float32() (float32, error) {
|
||||
v, err := strconv.ParseFloat(f.String(), 32)
|
||||
return float32(v), err
|
||||
}
|
||||
|
||||
// string to float64
|
||||
func (f StrTo) Float64() (float64, error) {
|
||||
return strconv.ParseFloat(f.String(), 64)
|
||||
}
|
||||
|
||||
// string to int
|
||||
func (f StrTo) Int() (int, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int(v), err
|
||||
}
|
||||
|
||||
// string to int8
|
||||
func (f StrTo) Int8() (int8, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||
return int8(v), err
|
||||
}
|
||||
|
||||
// string to int16
|
||||
func (f StrTo) Int16() (int16, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 16)
|
||||
return int16(v), err
|
||||
}
|
||||
|
||||
// string to int32
|
||||
func (f StrTo) Int32() (int32, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int32(v), err
|
||||
}
|
||||
|
||||
// string to int64
|
||||
func (f StrTo) Int64() (int64, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 64)
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
// string to uint
|
||||
func (f StrTo) Uint() (uint, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint(v), err
|
||||
}
|
||||
|
||||
// string to uint8
|
||||
func (f StrTo) Uint8() (uint8, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||
return uint8(v), err
|
||||
}
|
||||
|
||||
// string to uint16
|
||||
func (f StrTo) Uint16() (uint16, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 16)
|
||||
return uint16(v), err
|
||||
}
|
||||
|
||||
// string to uint31
|
||||
func (f StrTo) Uint32() (uint32, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint32(v), err
|
||||
}
|
||||
|
||||
// string to uint64
|
||||
func (f StrTo) Uint64() (uint64, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 64)
|
||||
return uint64(v), err
|
||||
}
|
||||
|
||||
// string to string
|
||||
func (f StrTo) String() string {
|
||||
if f.Exist() {
|
||||
return string(f)
|
||||
@ -96,6 +113,7 @@ func (f StrTo) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// interface to string
|
||||
func ToStr(value interface{}, args ...int) (s string) {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
|
||||
return s
|
||||
}
|
||||
|
||||
// interface to int64
|
||||
func ToInt64(value interface{}) (d int64) {
|
||||
val := reflect.ValueOf(value)
|
||||
switch value.(type) {
|
||||
@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
|
||||
return
|
||||
}
|
||||
|
||||
// snake string, XxYy to xx_yy
|
||||
func snakeString(s string) string {
|
||||
data := make([]byte, 0, len(s)*2)
|
||||
j := false
|
||||
@ -164,6 +184,7 @@ func snakeString(s string) string {
|
||||
return strings.ToLower(string(data[:len(data)]))
|
||||
}
|
||||
|
||||
// camel string, xx_yy to XxYy
|
||||
func camelString(s string) string {
|
||||
data := make([]byte, 0, len(s))
|
||||
j := false
|
||||
@ -190,6 +211,7 @@ func camelString(s string) string {
|
||||
|
||||
type argString []string
|
||||
|
||||
// get string by index from string slice
|
||||
func (a argString) Get(i int, args ...string) (r string) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
|
||||
|
||||
type argInt []int
|
||||
|
||||
// get int by index from int slice
|
||||
func (a argInt) Get(i int, args ...int) (r int) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
|
||||
|
||||
type argAny []interface{}
|
||||
|
||||
// get interface by index from interface slice
|
||||
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
||||
return
|
||||
}
|
||||
|
||||
// parse time to string with location
|
||||
func timeParse(dateString, format string) (time.Time, error) {
|
||||
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
|
||||
return tp, err
|
||||
}
|
||||
|
||||
// format time string
|
||||
func timeFormat(t time.Time, format string) string {
|
||||
return t.Format(format)
|
||||
}
|
||||
|
||||
// get pointer indirect type
|
||||
func indirectType(v reflect.Type) reflect.Type {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr:
|
||||
|
81
router.go
81
router.go
@ -1,7 +1,10 @@
|
||||
package beego
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -30,6 +33,14 @@ const (
|
||||
var (
|
||||
// supported http methods.
|
||||
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
|
||||
// these beego.Controller's methods shouldn't reflect to AutoRouter
|
||||
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
|
||||
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
|
||||
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
|
||||
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
|
||||
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
|
||||
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
|
||||
"GetControllerAndAction"}
|
||||
)
|
||||
|
||||
type controllerInfo struct {
|
||||
@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
||||
params := make(map[int]string)
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
//a user may choose to override the defult expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
||||
j++
|
||||
}
|
||||
if strings.HasPrefix(part, "*") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
if part == "*.*" {
|
||||
params[j] = ":path"
|
||||
parts[i] = "([^.]+).([^.]+)"
|
||||
@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
||||
// Add auto router to ControllerRegistor.
|
||||
// example beego.AddAuto(&MainContorlller{}),
|
||||
// MainController has method List and Page.
|
||||
// visit the url /main/list to exec List function
|
||||
// /main/page to exec Page function.
|
||||
// visit the url /main/list to execute List function
|
||||
// /main/page to execute Page function.
|
||||
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||||
p.enableAuto = true
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||
}
|
||||
for i := 0; i < rt.NumMethod(); i++ {
|
||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add auto router to ControllerRegistor with prefix.
|
||||
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
|
||||
// MainController has method List and Page.
|
||||
// visit the url /admin/main/list to execute List function
|
||||
// /admin/main/page to execute Page function.
|
||||
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
|
||||
p.enableAuto = true
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
rt := reflectVal.Type()
|
||||
ct := reflect.Indirect(reflectVal).Type()
|
||||
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
|
||||
if _, ok := p.autoRouter[firstParam]; ok {
|
||||
return
|
||||
} else {
|
||||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||
}
|
||||
for i := 0; i < rt.NumMethod(); i++ {
|
||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [Deprecated] use InsertFilter.
|
||||
// Add FilterFunc with pattern for action.
|
||||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
|
||||
mr := buildFilter(pattern, filter)
|
||||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
|
||||
mr, err := buildFilter(pattern, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch action {
|
||||
case "BeforeRouter":
|
||||
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
|
||||
@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
|
||||
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
|
||||
}
|
||||
p.enableFilter = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a FilterFunc with pattern rule and action constant.
|
||||
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) {
|
||||
mr := buildFilter(pattern, filter)
|
||||
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
|
||||
mr, err := buildFilter(pattern, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.filters[pos] = append(p.filters[pos], mr)
|
||||
p.enableFilter = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// UrlFor does another controller handler in this request function.
|
||||
@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
// session init
|
||||
if SessionOn {
|
||||
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
|
||||
defer context.Input.CruSession.SessionRelease()
|
||||
defer func() {
|
||||
context.Input.CruSession.SessionRelease(w)
|
||||
}()
|
||||
}
|
||||
|
||||
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
|
||||
@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
// pattern /admin url /admin 200 /admin/ 200
|
||||
// pattern /admin/ url /admin 301 /admin/ 200
|
||||
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
|
||||
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
|
||||
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
|
||||
http.Redirect(w, r, requestPath+"/", 301)
|
||||
goto Admin
|
||||
}
|
||||
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern {
|
||||
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
|
||||
runMethod = p.getRunMethod(r.Method, context, route)
|
||||
if runMethod != "" {
|
||||
runrouter = route.controllerType
|
||||
@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
|
||||
w.started = true
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// hijacker for http
|
||||
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := w.writer.(http.Hijacker)
|
||||
if !ok {
|
||||
println("supported?")
|
||||
return nil, nil, errors.New("webserver doesn't support hijacking")
|
||||
}
|
||||
return hj.Hijack()
|
||||
}
|
||||
|
@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
|
||||
t.Errorf(w.Body.String() + "user define func can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoPrefix(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddAutoPrefix("/admin", &TestController{})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am list" {
|
||||
t.Errorf("TestAutoPrefix can't run")
|
||||
}
|
||||
}
|
||||
|
@ -28,21 +28,21 @@ Then in you web app init the global session manager
|
||||
* Use **memory** as provider:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"")
|
||||
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **file** as provider, the last param is the path where you want file to be stored:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp")
|
||||
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie")
|
||||
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
@ -50,15 +50,24 @@ Then in you web app init the global session manager
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager(
|
||||
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value")
|
||||
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **Cookie** as provider:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager(
|
||||
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
|
||||
Finally in the handlerfunc you can use it like this
|
||||
|
||||
func login(w http.ResponseWriter, r *http.Request) {
|
||||
sess := globalSessions.SessionStart(w, r)
|
||||
defer sess.SessionRelease()
|
||||
defer sess.SessionRelease(w)
|
||||
username := sess.Get("username")
|
||||
fmt.Println(username)
|
||||
if r.Method == "GET" {
|
||||
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
|
||||
|
||||
Writing a provider is easy. You only need to define two struct types
|
||||
(Session and Provider), which satisfy the interface definition.
|
||||
Maybe you will find the **memory** provider as good example.
|
||||
Maybe you will find the **memory** provider is a good example.
|
||||
|
||||
type SessionStore interface {
|
||||
Set(key, value interface{}) error //set session value
|
||||
Get(key interface{}) interface{} //get session value
|
||||
Delete(key interface{}) error //delete session value
|
||||
SessionID() string //back current sessionID
|
||||
SessionRelease() // release the resource & save data to provider
|
||||
Flush() error //delete all data
|
||||
Set(key, value interface{}) error //set session value
|
||||
Get(key interface{}) interface{} //get session value
|
||||
Delete(key interface{}) error //delete session value
|
||||
SessionID() string //back current sessionID
|
||||
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
|
||||
Flush() error //delete all data
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
SessionInit(maxlifetime int64, savePath string) error
|
||||
SessionInit(gclifetime int64, config string) error
|
||||
SessionRead(sid string) (SessionStore, error)
|
||||
SessionExist(sid string) bool
|
||||
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
||||
|
145
session/sess_cookie.go
Normal file
145
session/sess_cookie.go
Normal file
@ -0,0 +1,145 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var cookiepder = &CookieProvider{}
|
||||
|
||||
type CookieSessionStore struct {
|
||||
sid string
|
||||
values map[interface{}]interface{} //session data
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) Set(key, value interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
st.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) Get(key interface{}) interface{} {
|
||||
st.lock.RLock()
|
||||
defer st.lock.RUnlock()
|
||||
if v, ok := st.values[key]; ok {
|
||||
return v
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) Delete(key interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
delete(st.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) Flush() error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
st.values = make(map[interface{}]interface{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) SessionID() string {
|
||||
return st.sid
|
||||
}
|
||||
|
||||
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
str, err := encodeCookie(cookiepder.block,
|
||||
cookiepder.config.SecurityKey,
|
||||
cookiepder.config.SecurityName,
|
||||
st.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
|
||||
Value: url.QueryEscape(str),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: cookiepder.config.Secure}
|
||||
http.SetCookie(w, cookie)
|
||||
return
|
||||
}
|
||||
|
||||
type cookieConfig struct {
|
||||
SecurityKey string `json:"securityKey"`
|
||||
BlockKey string `json:"blockKey"`
|
||||
SecurityName string `json:"securityName"`
|
||||
CookieName string `json:"cookieName"`
|
||||
Secure bool `json:"secure"`
|
||||
Maxage int `json:"maxage"`
|
||||
}
|
||||
|
||||
type CookieProvider struct {
|
||||
maxlifetime int64
|
||||
config *cookieConfig
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
|
||||
pder.config = &cookieConfig{}
|
||||
err := json.Unmarshal([]byte(config), pder.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pder.config.BlockKey == "" {
|
||||
pder.config.BlockKey = string(generateRandomKey(16))
|
||||
}
|
||||
if pder.config.SecurityName == "" {
|
||||
pder.config.SecurityName = string(generateRandomKey(20))
|
||||
}
|
||||
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
maps, _ := decodeCookie(pder.block,
|
||||
pder.config.SecurityKey,
|
||||
pder.config.SecurityName,
|
||||
sid, pder.maxlifetime)
|
||||
if maps == nil {
|
||||
maps = make(map[interface{}]interface{})
|
||||
}
|
||||
rs := &CookieSessionStore{sid: sid, values: maps}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionExist(sid string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionDestroy(sid string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionGC() {
|
||||
return
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionAll() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (pder *CookieProvider) SessionUpdate(sid string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("cookie", cookiepder)
|
||||
}
|
38
session/sess_cookie_test.go
Normal file
38
session/sess_cookie_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCookie(t *testing.T) {
|
||||
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
|
||||
globalSessions, err := NewManager("cookie", config)
|
||||
if err != nil {
|
||||
t.Fatal("init cookie session err", err)
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
sess := globalSessions.SessionStart(w, r)
|
||||
err = sess.Set("username", "astaxie")
|
||||
if err != nil {
|
||||
t.Fatal("set error,", err)
|
||||
}
|
||||
if username := sess.Get("username"); username != "astaxie" {
|
||||
t.Fatal("get username error")
|
||||
}
|
||||
sess.SessionRelease(w)
|
||||
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
|
||||
t.Fatal("setcookie error")
|
||||
} else {
|
||||
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
|
||||
for k, v := range parts {
|
||||
nameval := strings.Split(v, "=")
|
||||
if k == 0 && nameval[0] != "gosessionid" {
|
||||
t.Fatal("error")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
|
||||
return fs.sid
|
||||
}
|
||||
|
||||
func (fs *FileSessionStore) SessionRelease() {
|
||||
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer fs.f.Close()
|
||||
b, err := encodeGob(fs.values)
|
||||
if err != nil {
|
||||
|
@ -1,38 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register([]interface{}{})
|
||||
gob.Register(map[int]interface{}{})
|
||||
gob.Register(map[string]interface{}{})
|
||||
gob.Register(map[interface{}]interface{}{})
|
||||
gob.Register(map[string]string{})
|
||||
gob.Register(map[int]string{})
|
||||
gob.Register(map[int]int{})
|
||||
gob.Register(map[int]int64{})
|
||||
}
|
||||
|
||||
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
return []byte(""), err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
|
||||
buf := bytes.NewBuffer(encoded)
|
||||
dec := gob.NewDecoder(buf)
|
||||
var out map[interface{}]interface{}
|
||||
err := dec.Decode(&out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
@ -2,6 +2,7 @@ package session
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -9,9 +10,9 @@ import (
|
||||
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
|
||||
|
||||
type MemSessionStore struct {
|
||||
sid string //session id唯一标示
|
||||
timeAccessed time.Time //最后访问时间
|
||||
value map[interface{}]interface{} //session里面存储的值
|
||||
sid string //session id
|
||||
timeAccessed time.Time //last access time
|
||||
value map[interface{}]interface{} //session store
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
|
||||
return st.sid
|
||||
}
|
||||
|
||||
func (st *MemSessionStore) SessionRelease() {
|
||||
|
||||
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
type MemProvider struct {
|
||||
|
35
session/sess_mem_test.go
Normal file
35
session/sess_mem_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMem(t *testing.T) {
|
||||
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
|
||||
go globalSessions.GC()
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
sess := globalSessions.SessionStart(w, r)
|
||||
defer sess.SessionRelease(w)
|
||||
err := sess.Set("username", "astaxie")
|
||||
if err != nil {
|
||||
t.Fatal("set error,", err)
|
||||
}
|
||||
if username := sess.Get("username"); username != "astaxie" {
|
||||
t.Fatal("get username error")
|
||||
}
|
||||
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
|
||||
t.Fatal("setcookie error")
|
||||
} else {
|
||||
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
|
||||
for k, v := range parts {
|
||||
nameval := strings.Split(v, "=")
|
||||
if k == 0 && nameval[0] != "gosessionid" {
|
||||
t.Fatal("error")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ package session
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -60,15 +61,15 @@ func (st *MysqlSessionStore) SessionID() string {
|
||||
return st.sid
|
||||
}
|
||||
|
||||
func (st *MysqlSessionStore) SessionRelease() {
|
||||
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer st.c.Close()
|
||||
if len(st.values) > 0 {
|
||||
b, err := encodeGob(st.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
|
||||
b, err := encodeGob(st.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
|
||||
b, time.Now().Unix(), st.sid)
|
||||
|
||||
}
|
||||
|
||||
type MysqlProvider struct {
|
||||
@ -96,7 +97,8 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
if err == sql.ErrNoRows {
|
||||
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix())
|
||||
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
|
||||
sid, "", time.Now().Unix())
|
||||
}
|
||||
var kv map[interface{}]interface{}
|
||||
if len(sessiondata) == 0 {
|
||||
@ -113,6 +115,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
|
||||
func (mp *MysqlProvider) SessionExist(sid string) bool {
|
||||
c := mp.connectInit()
|
||||
defer c.Close()
|
||||
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -58,16 +59,14 @@ func (rs *RedisSessionStore) SessionID() string {
|
||||
return rs.sid
|
||||
}
|
||||
|
||||
func (rs *RedisSessionStore) SessionRelease() {
|
||||
func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer rs.c.Close()
|
||||
if len(rs.values) > 0 {
|
||||
b, err := encodeGob(rs.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rs.c.Do("SET", rs.sid, string(b))
|
||||
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
|
||||
b, err := encodeGob(rs.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rs.c.Do("SET", rs.sid, string(b))
|
||||
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
|
||||
}
|
||||
|
||||
type RedisProvider struct {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
|
||||
t.Error("decode int error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
str := generateRandomKey(20)
|
||||
if len(str) != 20 {
|
||||
t.Fatal("generate length is not equal to 20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieEncodeDecode(t *testing.T) {
|
||||
hashKey := "testhashKey"
|
||||
blockkey := generateRandomKey(16)
|
||||
block, err := aes.NewCipher(blockkey)
|
||||
if err != nil {
|
||||
t.Fatal("NewCipher:", err)
|
||||
}
|
||||
securityName := string(generateRandomKey(20))
|
||||
val := make(map[interface{}]interface{})
|
||||
val["name"] = "astaxie"
|
||||
val["gender"] = "male"
|
||||
str, err := encodeCookie(block, hashKey, securityName, val)
|
||||
if err != nil {
|
||||
t.Fatal("encodeCookie:", err)
|
||||
}
|
||||
dst := make(map[interface{}]interface{})
|
||||
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
|
||||
if err != nil {
|
||||
t.Fatal("decodeCookie", err)
|
||||
}
|
||||
if dst["name"] != "astaxie" {
|
||||
t.Fatal("dst get map error")
|
||||
}
|
||||
if dst["gender"] != "male" {
|
||||
t.Fatal("dst get map error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
s := `{"cookieName":"gosessionid","gclifetime":3600}`
|
||||
cf := new(managerConfig)
|
||||
cf.EnableSetCookie = true
|
||||
err := json.Unmarshal([]byte(s), cf)
|
||||
if err != nil {
|
||||
t.Fatal("parse json error,", err)
|
||||
}
|
||||
if cf.CookieName != "gosessionid" {
|
||||
t.Fatal("parseconfig get cookiename error")
|
||||
}
|
||||
if cf.Gclifetime != 3600 {
|
||||
t.Fatal("parseconfig get gclifetime error")
|
||||
}
|
||||
|
||||
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
|
||||
cf2 := new(managerConfig)
|
||||
cf2.EnableSetCookie = true
|
||||
err = json.Unmarshal([]byte(cc), cf2)
|
||||
if err != nil {
|
||||
t.Fatal("parse json error,", err)
|
||||
}
|
||||
if cf2.CookieName != "gosessionid" {
|
||||
t.Fatal("parseconfig get cookiename error")
|
||||
}
|
||||
if cf2.Gclifetime != 3600 {
|
||||
t.Fatal("parseconfig get gclifetime error")
|
||||
}
|
||||
if cf2.EnableSetCookie != false {
|
||||
t.Fatal("parseconfig get enableSetCookie error")
|
||||
}
|
||||
cconfig := new(cookieConfig)
|
||||
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
|
||||
if err != nil {
|
||||
t.Fatal("parse ProviderConfig err,", err)
|
||||
}
|
||||
if cconfig.CookieName != "gosessionid" {
|
||||
t.Fatal("ProviderConfig get cookieName error")
|
||||
}
|
||||
if cconfig.SecurityKey != "beegocookiehashkey" {
|
||||
t.Fatal("ProviderConfig get securityKey error")
|
||||
}
|
||||
}
|
||||
|
188
session/sess_utils.go
Normal file
188
session/sess_utils.go
Normal file
@ -0,0 +1,188 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register([]interface{}{})
|
||||
gob.Register(map[int]interface{}{})
|
||||
gob.Register(map[string]interface{}{})
|
||||
gob.Register(map[interface{}]interface{}{})
|
||||
gob.Register(map[string]string{})
|
||||
gob.Register(map[int]string{})
|
||||
gob.Register(map[int]int{})
|
||||
gob.Register(map[int]int64{})
|
||||
}
|
||||
|
||||
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
return []byte(""), err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
|
||||
buf := bytes.NewBuffer(encoded)
|
||||
dec := gob.NewDecoder(buf)
|
||||
var out map[interface{}]interface{}
|
||||
err := dec.Decode(&out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// generateRandomKey creates a random key with the given strength.
|
||||
func generateRandomKey(strength int) []byte {
|
||||
k := make([]byte, strength)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return nil
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// Encryption -----------------------------------------------------------------
|
||||
|
||||
// encrypt encrypts a value using the given block in counter mode.
|
||||
//
|
||||
// A random initialization vector (http://goo.gl/zF67k) with the length of the
|
||||
// block size is prepended to the resulting ciphertext.
|
||||
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||
iv := generateRandomKey(block.BlockSize())
|
||||
if iv == nil {
|
||||
return nil, errors.New("encrypt: failed to generate random iv")
|
||||
}
|
||||
// Encrypt it.
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(value, value)
|
||||
// Return iv + ciphertext.
|
||||
return append(iv, value...), nil
|
||||
}
|
||||
|
||||
// decrypt decrypts a value using the given block in counter mode.
|
||||
//
|
||||
// The value to be decrypted must be prepended by a initialization vector
|
||||
// (http://goo.gl/zF67k) with the length of the block size.
|
||||
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||
size := block.BlockSize()
|
||||
if len(value) > size {
|
||||
// Extract iv.
|
||||
iv := value[:size]
|
||||
// Extract ciphertext.
|
||||
value = value[size:]
|
||||
// Decrypt it.
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(value, value)
|
||||
return value, nil
|
||||
}
|
||||
return nil, errors.New("decrypt: the value could not be decrypted")
|
||||
}
|
||||
|
||||
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
|
||||
var err error
|
||||
var b []byte
|
||||
// 1. encodeGob.
|
||||
if b, err = encodeGob(value); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 2. Encrypt (optional).
|
||||
if b, err = encrypt(block, b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
b = encode(b)
|
||||
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
|
||||
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
|
||||
h := hmac.New(sha1.New, []byte(hashKey))
|
||||
h.Write(b)
|
||||
sig := h.Sum(nil)
|
||||
// Append mac, remove name.
|
||||
b = append(b, sig...)[len(name)+1:]
|
||||
// 4. Encode to base64.
|
||||
b = encode(b)
|
||||
// Done.
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
|
||||
// 1. Decode from base64.
|
||||
b, err := decode([]byte(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. Verify MAC. Value is "date|value|mac".
|
||||
parts := bytes.SplitN(b, []byte("|"), 3)
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("Decode: invalid value %v")
|
||||
}
|
||||
|
||||
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
|
||||
h := hmac.New(sha1.New, []byte(hashKey))
|
||||
h.Write(b)
|
||||
sig := h.Sum(nil)
|
||||
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
|
||||
return nil, errors.New("Decode: the value is not valid")
|
||||
}
|
||||
// 3. Verify date ranges.
|
||||
var t1 int64
|
||||
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
|
||||
return nil, errors.New("Decode: invalid timestamp")
|
||||
}
|
||||
t2 := time.Now().UTC().Unix()
|
||||
if t1 > t2 {
|
||||
return nil, errors.New("Decode: timestamp is too new")
|
||||
}
|
||||
if t1 < t2-gcmaxlifetime {
|
||||
return nil, errors.New("Decode: expired timestamp")
|
||||
}
|
||||
// 4. Decrypt (optional).
|
||||
b, err = decode(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b, err = decrypt(block, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 5. decodeGob.
|
||||
if dst, err := decodeGob(b); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return dst, nil
|
||||
}
|
||||
// Done.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Encoding -------------------------------------------------------------------
|
||||
|
||||
// encode encodes a value using base64.
|
||||
func encode(value []byte) []byte {
|
||||
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
|
||||
base64.URLEncoding.Encode(encoded, value)
|
||||
return encoded
|
||||
}
|
||||
|
||||
// decode decodes a cookie using base64.
|
||||
func decode(value []byte) ([]byte, error) {
|
||||
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
|
||||
b, err := base64.URLEncoding.Decode(decoded, value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decoded[:b], nil
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -14,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type SessionStore interface {
|
||||
Set(key, value interface{}) error //set session value
|
||||
Get(key interface{}) interface{} //get session value
|
||||
Delete(key interface{}) error //delete session value
|
||||
SessionID() string //back current sessionID
|
||||
SessionRelease() // release the resource & save data to provider
|
||||
Flush() error //delete all data
|
||||
Set(key, value interface{}) error //set session value
|
||||
Get(key interface{}) interface{} //get session value
|
||||
Delete(key interface{}) error //delete session value
|
||||
SessionID() string //back current sessionID
|
||||
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
|
||||
Flush() error //delete all data
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
SessionInit(maxlifetime int64, savePath string) error
|
||||
SessionInit(gclifetime int64, config string) error
|
||||
SessionRead(sid string) (SessionStore, error)
|
||||
SessionExist(sid string) bool
|
||||
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
||||
@ -47,15 +48,22 @@ func Register(name string, provide Provider) {
|
||||
provides[name] = provide
|
||||
}
|
||||
|
||||
type managerConfig struct {
|
||||
CookieName string `json:"cookieName"`
|
||||
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
||||
Gclifetime int64 `json:"gclifetime"`
|
||||
Maxlifetime int64 `json:"maxLifetime"`
|
||||
Maxage int `json:"maxage"`
|
||||
Secure bool `json:"secure"`
|
||||
SessionIDHashFunc string `json:"sessionIDHashFunc"`
|
||||
SessionIDHashKey string `json:"sessionIDHashKey"`
|
||||
CookieLifeTime int64 `json:"cookieLifeTime"`
|
||||
ProviderConfig string `json:"providerConfig"`
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
cookieName string //private cookiename
|
||||
provider Provider
|
||||
maxlifetime int64
|
||||
hashfunc string //support md5 & sha1
|
||||
hashkey string
|
||||
maxage int //cookielifetime
|
||||
secure bool
|
||||
options []interface{}
|
||||
provider Provider
|
||||
config *managerConfig
|
||||
}
|
||||
|
||||
//options
|
||||
@ -63,74 +71,54 @@ type Manager struct {
|
||||
//2. hashfunc default sha1
|
||||
//3. hashkey default beegosessionkey
|
||||
//4. maxage default is none
|
||||
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
|
||||
func NewManager(provideName, config string) (*Manager, error) {
|
||||
provider, ok := provides[provideName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
|
||||
}
|
||||
provider.SessionInit(maxlifetime, savePath)
|
||||
secure := false
|
||||
if len(options) > 0 {
|
||||
secure = options[0].(bool)
|
||||
cf := new(managerConfig)
|
||||
cf.EnableSetCookie = true
|
||||
err := json.Unmarshal([]byte(config), cf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashfunc := "sha1"
|
||||
if len(options) > 1 {
|
||||
hashfunc = options[1].(string)
|
||||
if cf.Maxlifetime == 0 {
|
||||
cf.Maxlifetime = cf.Gclifetime
|
||||
}
|
||||
hashkey := "beegosessionkey"
|
||||
if len(options) > 2 {
|
||||
hashkey = options[2].(string)
|
||||
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxage := -1
|
||||
if len(options) > 3 {
|
||||
switch options[3].(type) {
|
||||
case int:
|
||||
if options[3].(int) > 0 {
|
||||
maxage = options[3].(int)
|
||||
} else if options[3].(int) < 0 {
|
||||
maxage = 0
|
||||
}
|
||||
case int64:
|
||||
if options[3].(int64) > 0 {
|
||||
maxage = int(options[3].(int64))
|
||||
} else if options[3].(int64) < 0 {
|
||||
maxage = 0
|
||||
}
|
||||
case int32:
|
||||
if options[3].(int32) > 0 {
|
||||
maxage = int(options[3].(int32))
|
||||
} else if options[3].(int32) < 0 {
|
||||
maxage = 0
|
||||
}
|
||||
}
|
||||
if cf.SessionIDHashFunc == "" {
|
||||
cf.SessionIDHashFunc = "sha1"
|
||||
}
|
||||
if cf.SessionIDHashKey == "" {
|
||||
cf.SessionIDHashKey = string(generateRandomKey(16))
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
provider: provider,
|
||||
cookieName: cookieName,
|
||||
maxlifetime: maxlifetime,
|
||||
hashfunc: hashfunc,
|
||||
hashkey: hashkey,
|
||||
maxage: maxage,
|
||||
secure: secure,
|
||||
options: options,
|
||||
provider,
|
||||
cf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//get Session
|
||||
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
||||
cookie, err := r.Cookie(manager.cookieName)
|
||||
cookie, err := r.Cookie(manager.config.CookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
sid := manager.sessionId(r)
|
||||
session, _ = manager.provider.SessionRead(sid)
|
||||
cookie = &http.Cookie{Name: manager.cookieName,
|
||||
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||
Value: url.QueryEscape(sid),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: manager.secure}
|
||||
if manager.maxage >= 0 {
|
||||
cookie.MaxAge = manager.maxage
|
||||
Secure: manager.config.Secure}
|
||||
if manager.config.Maxage >= 0 {
|
||||
cookie.MaxAge = manager.config.Maxage
|
||||
}
|
||||
if manager.config.EnableSetCookie {
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
r.AddCookie(cookie)
|
||||
} else {
|
||||
sid, _ := url.QueryUnescape(cookie.Value)
|
||||
@ -139,15 +127,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
|
||||
} else {
|
||||
sid = manager.sessionId(r)
|
||||
session, _ = manager.provider.SessionRead(sid)
|
||||
cookie = &http.Cookie{Name: manager.cookieName,
|
||||
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||
Value: url.QueryEscape(sid),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: manager.secure}
|
||||
if manager.maxage >= 0 {
|
||||
cookie.MaxAge = manager.maxage
|
||||
Secure: manager.config.Secure}
|
||||
if manager.config.Maxage >= 0 {
|
||||
cookie.MaxAge = manager.config.Maxage
|
||||
}
|
||||
if manager.config.EnableSetCookie {
|
||||
http.SetCookie(w, cookie)
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
}
|
||||
@ -156,13 +146,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
|
||||
|
||||
//Destroy sessionid
|
||||
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(manager.cookieName)
|
||||
cookie, err := r.Cookie(manager.config.CookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
return
|
||||
} else {
|
||||
manager.provider.SessionDestroy(cookie.Value)
|
||||
expiration := time.Now()
|
||||
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1}
|
||||
cookie := http.Cookie{Name: manager.config.CookieName,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Expires: expiration,
|
||||
MaxAge: -1}
|
||||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
}
|
||||
@ -174,20 +168,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro
|
||||
|
||||
func (manager *Manager) GC() {
|
||||
manager.provider.SessionGC()
|
||||
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
|
||||
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
|
||||
}
|
||||
|
||||
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
||||
sid := manager.sessionId(r)
|
||||
cookie, err := r.Cookie(manager.cookieName)
|
||||
cookie, err := r.Cookie(manager.config.CookieName)
|
||||
if err != nil && cookie.Value == "" {
|
||||
//delete old cookie
|
||||
session, _ = manager.provider.SessionRead(sid)
|
||||
cookie = &http.Cookie{Name: manager.cookieName,
|
||||
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||
Value: url.QueryEscape(sid),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: manager.secure,
|
||||
Secure: manager.config.Secure,
|
||||
}
|
||||
} else {
|
||||
oldsid, _ := url.QueryUnescape(cookie.Value)
|
||||
@ -196,8 +190,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
|
||||
cookie.HttpOnly = true
|
||||
cookie.Path = "/"
|
||||
}
|
||||
if manager.maxage >= 0 {
|
||||
cookie.MaxAge = manager.maxage
|
||||
if manager.config.Maxage >= 0 {
|
||||
cookie.MaxAge = manager.config.Maxage
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
r.AddCookie(cookie)
|
||||
@ -209,12 +203,12 @@ func (manager *Manager) GetActiveSession() int {
|
||||
}
|
||||
|
||||
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
|
||||
manager.hashfunc = hasfunc
|
||||
manager.hashkey = hashkey
|
||||
manager.config.SessionIDHashFunc = hasfunc
|
||||
manager.config.SessionIDHashKey = hashkey
|
||||
}
|
||||
|
||||
func (manager *Manager) SetSecure(secure bool) {
|
||||
manager.secure = secure
|
||||
manager.config.Secure = secure
|
||||
}
|
||||
|
||||
//remote_addr cruunixnano randdata
|
||||
@ -224,16 +218,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) {
|
||||
return ""
|
||||
}
|
||||
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
|
||||
if manager.hashfunc == "md5" {
|
||||
if manager.config.SessionIDHashFunc == "md5" {
|
||||
h := md5.New()
|
||||
h.Write([]byte(sig))
|
||||
sid = hex.EncodeToString(h.Sum(nil))
|
||||
} else if manager.hashfunc == "sha1" {
|
||||
h := hmac.New(sha1.New, []byte(manager.hashkey))
|
||||
} else if manager.config.SessionIDHashFunc == "sha1" {
|
||||
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
|
||||
fmt.Fprintf(h, "%s", sig)
|
||||
sid = hex.EncodeToString(h.Sum(nil))
|
||||
} else {
|
||||
h := hmac.New(sha1.New, []byte(manager.hashkey))
|
||||
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
|
||||
fmt.Fprintf(h, "%s", sig)
|
||||
sid = hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
45
utils/captcha/README.md
Normal file
45
utils/captcha/README.md
Normal file
@ -0,0 +1,45 @@
|
||||
# Captcha
|
||||
|
||||
an example for use captcha
|
||||
|
||||
```
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/astaxie/beego/cache"
|
||||
"github.com/astaxie/beego/utils/captcha"
|
||||
)
|
||||
|
||||
var cpt *captcha.Captcha
|
||||
|
||||
func init() {
|
||||
// use beego cache system store the captcha data
|
||||
store := cache.NewMemoryCache()
|
||||
cpt = captcha.NewWithFilter("/captcha/", store)
|
||||
}
|
||||
|
||||
type MainController struct {
|
||||
beego.Controller
|
||||
}
|
||||
|
||||
func (this *MainController) Get() {
|
||||
this.TplNames = "index.tpl"
|
||||
}
|
||||
|
||||
func (this *MainController) Post() {
|
||||
this.TplNames = "index.tpl"
|
||||
|
||||
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
|
||||
}
|
||||
```
|
||||
|
||||
template usage
|
||||
|
||||
```
|
||||
{{.Success}}
|
||||
<form action="/" method="post">
|
||||
{{create_captcha}}
|
||||
<input name="captcha" type="text">
|
||||
</form>
|
||||
```
|
248
utils/captcha/captcha.go
Normal file
248
utils/captcha/captcha.go
Normal file
@ -0,0 +1,248 @@
|
||||
// an example for use captcha
|
||||
//
|
||||
// ```
|
||||
// package controllers
|
||||
//
|
||||
// import (
|
||||
// "github.com/astaxie/beego"
|
||||
// "github.com/astaxie/beego/cache"
|
||||
// "github.com/astaxie/beego/utils/captcha"
|
||||
// )
|
||||
//
|
||||
// var cpt *captcha.Captcha
|
||||
//
|
||||
// func init() {
|
||||
// // use beego cache system store the captcha data
|
||||
// store := cache.NewMemoryCache()
|
||||
// cpt = captcha.NewWithFilter("/captcha/", store)
|
||||
// }
|
||||
//
|
||||
// type MainController struct {
|
||||
// beego.Controller
|
||||
// }
|
||||
//
|
||||
// func (this *MainController) Get() {
|
||||
// this.TplNames = "index.tpl"
|
||||
// }
|
||||
//
|
||||
// func (this *MainController) Post() {
|
||||
// this.TplNames = "index.tpl"
|
||||
//
|
||||
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// template usage
|
||||
//
|
||||
// ```
|
||||
// {{.Success}}
|
||||
// <form action="/" method="post">
|
||||
// {{create_captcha}}
|
||||
// <input name="captcha" type="text">
|
||||
// </form>
|
||||
// ```
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/astaxie/beego/cache"
|
||||
"github.com/astaxie/beego/context"
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
)
|
||||
|
||||
const (
|
||||
// default captcha attributes
|
||||
challengeNums = 6
|
||||
expiration = 600
|
||||
fieldIdName = "captcha_id"
|
||||
fieldCaptchaName = "captcha"
|
||||
cachePrefix = "captcha_"
|
||||
urlPrefix = "/captcha/"
|
||||
)
|
||||
|
||||
type Captcha struct {
|
||||
// beego cache store
|
||||
store cache.Cache
|
||||
|
||||
// url prefix for captcha image
|
||||
urlPrefix string
|
||||
|
||||
// specify captcha id input field name
|
||||
FieldIdName string
|
||||
// specify captcha result input field name
|
||||
FieldCaptchaName string
|
||||
|
||||
// captcha image width and height
|
||||
StdWidth int
|
||||
StdHeight int
|
||||
|
||||
// captcha chars nums
|
||||
ChallengeNums int
|
||||
|
||||
// captcha expiration seconds
|
||||
Expiration int64
|
||||
|
||||
// cache key prefix
|
||||
CachePrefix string
|
||||
}
|
||||
|
||||
func (c *Captcha) key(id string) string {
|
||||
return c.CachePrefix + id
|
||||
}
|
||||
|
||||
func (c *Captcha) genRandChars() []byte {
|
||||
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
|
||||
}
|
||||
|
||||
// beego filter handler for serve captcha image
|
||||
func (c *Captcha) Handler(ctx *context.Context) {
|
||||
var chars []byte
|
||||
|
||||
id := path.Base(ctx.Request.RequestURI)
|
||||
if i := strings.Index(id, "."); i != -1 {
|
||||
id = id[:i]
|
||||
}
|
||||
|
||||
key := c.key(id)
|
||||
|
||||
if v, ok := c.store.Get(key).([]byte); ok {
|
||||
chars = v
|
||||
} else {
|
||||
ctx.Output.SetStatus(404)
|
||||
ctx.WriteString("captcha not found")
|
||||
return
|
||||
}
|
||||
|
||||
// reload captcha
|
||||
if len(ctx.Input.Query("reload")) > 0 {
|
||||
chars = c.genRandChars()
|
||||
if err := c.store.Put(key, chars, c.Expiration); err != nil {
|
||||
ctx.Output.SetStatus(500)
|
||||
ctx.WriteString("captcha reload error")
|
||||
beego.Error("Reload Create Captcha Error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
img := NewImage(chars, c.StdWidth, c.StdHeight)
|
||||
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
|
||||
beego.Error("Write Captcha Image Error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tempalte func for output html
|
||||
func (c *Captcha) CreateCaptchaHtml() template.HTML {
|
||||
value, err := c.CreateCaptcha()
|
||||
if err != nil {
|
||||
beego.Error("Create Captcha Error:", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
// create html
|
||||
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
|
||||
`<a class="captcha" href="javascript:">`+
|
||||
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
|
||||
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
|
||||
}
|
||||
|
||||
// create a new captcha id
|
||||
func (c *Captcha) CreateCaptcha() (string, error) {
|
||||
// generate captcha id
|
||||
id := string(utils.RandomCreateBytes(15))
|
||||
|
||||
// get the captcha chars
|
||||
chars := c.genRandChars()
|
||||
|
||||
// save to store
|
||||
if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// verify from a request
|
||||
func (c *Captcha) VerifyReq(req *http.Request) bool {
|
||||
req.ParseForm()
|
||||
return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName))
|
||||
}
|
||||
|
||||
// direct verify id and challenge string
|
||||
func (c *Captcha) Verify(id string, challenge string) (success bool) {
|
||||
if len(challenge) == 0 || len(id) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var chars []byte
|
||||
|
||||
key := c.key(id)
|
||||
|
||||
if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) {
|
||||
chars = v
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// finally remove it
|
||||
c.store.Delete(key)
|
||||
}()
|
||||
|
||||
// verify challenge
|
||||
for i, c := range chars {
|
||||
if c != challenge[i]-48 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// create a new captcha.Captcha
|
||||
func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
|
||||
cpt := &Captcha{}
|
||||
cpt.store = store
|
||||
cpt.FieldIdName = fieldIdName
|
||||
cpt.FieldCaptchaName = fieldCaptchaName
|
||||
cpt.ChallengeNums = challengeNums
|
||||
cpt.Expiration = expiration
|
||||
cpt.CachePrefix = cachePrefix
|
||||
cpt.StdWidth = stdWidth
|
||||
cpt.StdHeight = stdHeight
|
||||
|
||||
if len(urlPrefix) == 0 {
|
||||
urlPrefix = urlPrefix
|
||||
}
|
||||
|
||||
if urlPrefix[len(urlPrefix)-1] != '/' {
|
||||
urlPrefix += "/"
|
||||
}
|
||||
|
||||
cpt.urlPrefix = urlPrefix
|
||||
|
||||
return cpt
|
||||
}
|
||||
|
||||
// create a new captcha.Captcha and auto AddFilter for serve captacha image
|
||||
// and add a tempalte func for output html
|
||||
func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
|
||||
cpt := NewCaptcha(urlPrefix, store)
|
||||
|
||||
// create filter for serve captcha image
|
||||
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
|
||||
|
||||
// add to template func map
|
||||
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
|
||||
|
||||
return cpt
|
||||
}
|
484
utils/captcha/image.go
Normal file
484
utils/captcha/image.go
Normal file
@ -0,0 +1,484 @@
|
||||
// modifiy and integrated to Beego from https://github.com/dchest/captcha
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"io"
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
fontWidth = 11
|
||||
fontHeight = 18
|
||||
blackChar = 1
|
||||
|
||||
// Standard width and height of a captcha image.
|
||||
stdWidth = 240
|
||||
stdHeight = 80
|
||||
// Maximum absolute skew factor of a single digit.
|
||||
maxSkew = 0.7
|
||||
// Number of background circles.
|
||||
circleCount = 20
|
||||
)
|
||||
|
||||
var font = [][]byte{
|
||||
{ // 0
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
},
|
||||
{ // 1
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
},
|
||||
{ // 2
|
||||
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
},
|
||||
{ // 3
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
},
|
||||
{ // 4
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
},
|
||||
{ // 5
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
},
|
||||
{ // 6
|
||||
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
|
||||
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
},
|
||||
{ // 7
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
},
|
||||
{ // 8
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
|
||||
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
},
|
||||
{ // 9
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
|
||||
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||
},
|
||||
}
|
||||
|
||||
type Image struct {
|
||||
*image.Paletted
|
||||
numWidth int
|
||||
numHeight int
|
||||
dotSize int
|
||||
}
|
||||
|
||||
var prng = &siprng{}
|
||||
|
||||
// randIntn returns a pseudorandom non-negative int in range [0, n).
|
||||
func randIntn(n int) int {
|
||||
return prng.Intn(n)
|
||||
}
|
||||
|
||||
// randInt returns a pseudorandom int in range [from, to].
|
||||
func randInt(from, to int) int {
|
||||
return prng.Intn(to+1-from) + from
|
||||
}
|
||||
|
||||
// randFloat returns a pseudorandom float64 in range [from, to].
|
||||
func randFloat(from, to float64) float64 {
|
||||
return (to-from)*prng.Float64() + from
|
||||
}
|
||||
|
||||
func randomPalette() color.Palette {
|
||||
p := make([]color.Color, circleCount+1)
|
||||
// Transparent color.
|
||||
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
|
||||
// Primary color.
|
||||
prim := color.RGBA{
|
||||
uint8(randIntn(129)),
|
||||
uint8(randIntn(129)),
|
||||
uint8(randIntn(129)),
|
||||
0xFF,
|
||||
}
|
||||
p[1] = prim
|
||||
// Circle colors.
|
||||
for i := 2; i <= circleCount; i++ {
|
||||
p[i] = randomBrightness(prim, 255)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// NewImage returns a new captcha image of the given width and height with the
|
||||
// given digits, where each digit must be in range 0-9.
|
||||
func NewImage(digits []byte, width, height int) *Image {
|
||||
m := new(Image)
|
||||
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
|
||||
m.calculateSizes(width, height, len(digits))
|
||||
// Randomly position captcha inside the image.
|
||||
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
|
||||
maxy := height - m.numHeight - m.dotSize*2
|
||||
var border int
|
||||
if width > height {
|
||||
border = height / 5
|
||||
} else {
|
||||
border = width / 5
|
||||
}
|
||||
x := randInt(border, maxx-border)
|
||||
y := randInt(border, maxy-border)
|
||||
// Draw digits.
|
||||
for _, n := range digits {
|
||||
m.drawDigit(font[n], x, y)
|
||||
x += m.numWidth + m.dotSize
|
||||
}
|
||||
// Draw strike-through line.
|
||||
m.strikeThrough()
|
||||
// Apply wave distortion.
|
||||
m.distort(randFloat(5, 10), randFloat(100, 200))
|
||||
// Fill image with random circles.
|
||||
m.fillWithCircles(circleCount, m.dotSize)
|
||||
return m
|
||||
}
|
||||
|
||||
// encodedPNG encodes an image to PNG and returns
|
||||
// the result as a byte slice.
|
||||
func (m *Image) encodedPNG() []byte {
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, m.Paletted); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// WriteTo writes captcha image in PNG format into the given writer.
|
||||
func (m *Image) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := w.Write(m.encodedPNG())
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
func (m *Image) calculateSizes(width, height, ncount int) {
|
||||
// Goal: fit all digits inside the image.
|
||||
var border int
|
||||
if width > height {
|
||||
border = height / 4
|
||||
} else {
|
||||
border = width / 4
|
||||
}
|
||||
// Convert everything to floats for calculations.
|
||||
w := float64(width - border*2)
|
||||
h := float64(height - border*2)
|
||||
// fw takes into account 1-dot spacing between digits.
|
||||
fw := float64(fontWidth + 1)
|
||||
fh := float64(fontHeight)
|
||||
nc := float64(ncount)
|
||||
// Calculate the width of a single digit taking into account only the
|
||||
// width of the image.
|
||||
nw := w / nc
|
||||
// Calculate the height of a digit from this width.
|
||||
nh := nw * fh / fw
|
||||
// Digit too high?
|
||||
if nh > h {
|
||||
// Fit digits based on height.
|
||||
nh = h
|
||||
nw = fw / fh * nh
|
||||
}
|
||||
// Calculate dot size.
|
||||
m.dotSize = int(nh / fh)
|
||||
// Save everything, making the actual width smaller by 1 dot to account
|
||||
// for spacing between digits.
|
||||
m.numWidth = int(nw) - m.dotSize
|
||||
m.numHeight = int(nh)
|
||||
}
|
||||
|
||||
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
|
||||
for x := fromX; x <= toX; x++ {
|
||||
m.SetColorIndex(x, y, colorIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
|
||||
f := 1 - radius
|
||||
dfx := 1
|
||||
dfy := -2 * radius
|
||||
xo := 0
|
||||
yo := radius
|
||||
|
||||
m.SetColorIndex(x, y+radius, colorIdx)
|
||||
m.SetColorIndex(x, y-radius, colorIdx)
|
||||
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
|
||||
|
||||
for xo < yo {
|
||||
if f >= 0 {
|
||||
yo--
|
||||
dfy += 2
|
||||
f += dfy
|
||||
}
|
||||
xo++
|
||||
dfx += 2
|
||||
f += dfx
|
||||
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
|
||||
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
|
||||
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
|
||||
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Image) fillWithCircles(n, maxradius int) {
|
||||
maxx := m.Bounds().Max.X
|
||||
maxy := m.Bounds().Max.Y
|
||||
for i := 0; i < n; i++ {
|
||||
colorIdx := uint8(randInt(1, circleCount-1))
|
||||
r := randInt(1, maxradius)
|
||||
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Image) strikeThrough() {
|
||||
maxx := m.Bounds().Max.X
|
||||
maxy := m.Bounds().Max.Y
|
||||
y := randInt(maxy/3, maxy-maxy/3)
|
||||
amplitude := randFloat(5, 20)
|
||||
period := randFloat(80, 180)
|
||||
dx := 2.0 * math.Pi / period
|
||||
for x := 0; x < maxx; x++ {
|
||||
xo := amplitude * math.Cos(float64(y)*dx)
|
||||
yo := amplitude * math.Sin(float64(x)*dx)
|
||||
for yn := 0; yn < m.dotSize; yn++ {
|
||||
r := randInt(0, m.dotSize)
|
||||
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Image) drawDigit(digit []byte, x, y int) {
|
||||
skf := randFloat(-maxSkew, maxSkew)
|
||||
xs := float64(x)
|
||||
r := m.dotSize / 2
|
||||
y += randInt(-r, r)
|
||||
for yo := 0; yo < fontHeight; yo++ {
|
||||
for xo := 0; xo < fontWidth; xo++ {
|
||||
if digit[yo*fontWidth+xo] != blackChar {
|
||||
continue
|
||||
}
|
||||
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
|
||||
}
|
||||
xs += skf
|
||||
x = int(xs)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Image) distort(amplude float64, period float64) {
|
||||
w := m.Bounds().Max.X
|
||||
h := m.Bounds().Max.Y
|
||||
|
||||
oldm := m.Paletted
|
||||
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
|
||||
|
||||
dx := 2.0 * math.Pi / period
|
||||
for x := 0; x < w; x++ {
|
||||
for y := 0; y < h; y++ {
|
||||
xo := amplude * math.Sin(float64(y)*dx)
|
||||
yo := amplude * math.Cos(float64(x)*dx)
|
||||
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
|
||||
}
|
||||
}
|
||||
m.Paletted = newm
|
||||
}
|
||||
|
||||
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
|
||||
minc := min3(c.R, c.G, c.B)
|
||||
maxc := max3(c.R, c.G, c.B)
|
||||
if maxc > max {
|
||||
return c
|
||||
}
|
||||
n := randIntn(int(max-maxc)) - int(minc)
|
||||
return color.RGBA{
|
||||
uint8(int(c.R) + n),
|
||||
uint8(int(c.G) + n),
|
||||
uint8(int(c.B) + n),
|
||||
uint8(c.A),
|
||||
}
|
||||
}
|
||||
|
||||
func min3(x, y, z uint8) (m uint8) {
|
||||
m = x
|
||||
if y < m {
|
||||
m = y
|
||||
}
|
||||
if z < m {
|
||||
m = z
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func max3(x, y, z uint8) (m uint8) {
|
||||
m = x
|
||||
if y > m {
|
||||
m = y
|
||||
}
|
||||
if z > m {
|
||||
m = z
|
||||
}
|
||||
return
|
||||
}
|
38
utils/captcha/image_test.go
Normal file
38
utils/captcha/image_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
type byteCounter struct {
|
||||
n int64
|
||||
}
|
||||
|
||||
func (bc *byteCounter) Write(b []byte) (int, error) {
|
||||
bc.n += int64(len(b))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func BenchmarkNewImage(b *testing.B) {
|
||||
b.StopTimer()
|
||||
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewImage(d, stdWidth, stdHeight)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkImageWriteTo(b *testing.B) {
|
||||
b.StopTimer()
|
||||
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
|
||||
b.StartTimer()
|
||||
counter := &byteCounter{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
img := NewImage(d, stdWidth, stdHeight)
|
||||
img.WriteTo(counter)
|
||||
b.SetBytes(counter.n)
|
||||
counter.n = 0
|
||||
}
|
||||
}
|
264
utils/captcha/siprng.go
Normal file
264
utils/captcha/siprng.go
Normal file
@ -0,0 +1,264 @@
|
||||
// modifiy and integrated to Beego from https://github.com/dchest/captcha
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// siprng is PRNG based on SipHash-2-4.
|
||||
type siprng struct {
|
||||
mu sync.Mutex
|
||||
k0, k1, ctr uint64
|
||||
}
|
||||
|
||||
// siphash implements SipHash-2-4, accepting a uint64 as a message.
|
||||
func siphash(k0, k1, m uint64) uint64 {
|
||||
// Initialization.
|
||||
v0 := k0 ^ 0x736f6d6570736575
|
||||
v1 := k1 ^ 0x646f72616e646f6d
|
||||
v2 := k0 ^ 0x6c7967656e657261
|
||||
v3 := k1 ^ 0x7465646279746573
|
||||
t := uint64(8) << 56
|
||||
|
||||
// Compression.
|
||||
v3 ^= m
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
v0 ^= m
|
||||
|
||||
// Compress last block.
|
||||
v3 ^= t
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
v0 ^= t
|
||||
|
||||
// Finalization.
|
||||
v2 ^= 0xff
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
// Round 3.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
// Round 4.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>(64-13)
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>(64-32)
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>(64-16)
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>(64-21)
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>(64-17)
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>(64-32)
|
||||
|
||||
return v0 ^ v1 ^ v2 ^ v3
|
||||
}
|
||||
|
||||
// rekey sets a new PRNG key, which is read from crypto/rand.
|
||||
func (p *siprng) rekey() {
|
||||
var k [16]byte
|
||||
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
p.k0 = binary.LittleEndian.Uint64(k[0:8])
|
||||
p.k1 = binary.LittleEndian.Uint64(k[8:16])
|
||||
p.ctr = 1
|
||||
}
|
||||
|
||||
// Uint64 returns a new pseudorandom uint64.
|
||||
// It rekeys PRNG on the first call and every 64 MB of generated data.
|
||||
func (p *siprng) Uint64() uint64 {
|
||||
p.mu.Lock()
|
||||
if p.ctr == 0 || p.ctr > 8*1024*1024 {
|
||||
p.rekey()
|
||||
}
|
||||
v := siphash(p.k0, p.k1, p.ctr)
|
||||
p.ctr++
|
||||
p.mu.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *siprng) Int63() int64 {
|
||||
return int64(p.Uint64() & 0x7fffffffffffffff)
|
||||
}
|
||||
|
||||
func (p *siprng) Uint32() uint32 {
|
||||
return uint32(p.Uint64())
|
||||
}
|
||||
|
||||
func (p *siprng) Int31() int32 {
|
||||
return int32(p.Uint32() & 0x7fffffff)
|
||||
}
|
||||
|
||||
func (p *siprng) Intn(n int) int {
|
||||
if n <= 0 {
|
||||
panic("invalid argument to Intn")
|
||||
}
|
||||
if n <= 1<<31-1 {
|
||||
return int(p.Int31n(int32(n)))
|
||||
}
|
||||
return int(p.Int63n(int64(n)))
|
||||
}
|
||||
|
||||
func (p *siprng) Int63n(n int64) int64 {
|
||||
if n <= 0 {
|
||||
panic("invalid argument to Int63n")
|
||||
}
|
||||
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
|
||||
v := p.Int63()
|
||||
for v > max {
|
||||
v = p.Int63()
|
||||
}
|
||||
return v % n
|
||||
}
|
||||
|
||||
func (p *siprng) Int31n(n int32) int32 {
|
||||
if n <= 0 {
|
||||
panic("invalid argument to Int31n")
|
||||
}
|
||||
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
|
||||
v := p.Int31()
|
||||
for v > max {
|
||||
v = p.Int31()
|
||||
}
|
||||
return v % n
|
||||
}
|
||||
|
||||
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }
|
19
utils/captcha/siprng_test.go
Normal file
19
utils/captcha/siprng_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package captcha
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSiphash(t *testing.T) {
|
||||
good := uint64(0xe849e8bb6ffe2567)
|
||||
cur := siphash(0, 0, 0)
|
||||
if cur != good {
|
||||
t.Fatalf("siphash: expected %x, got %x", good, cur)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSiprng(b *testing.B) {
|
||||
b.SetBytes(8)
|
||||
p := &siprng{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.Uint64()
|
||||
}
|
||||
}
|
299
utils/mail.go
Normal file
299
utils/mail.go
Normal file
@ -0,0 +1,299 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/mail"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxLineLength = 76
|
||||
)
|
||||
|
||||
// Email is the type used for email messages
|
||||
type Email struct {
|
||||
Auth smtp.Auth
|
||||
Identity string `json:"identity"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
From string `json:"from"`
|
||||
To []string
|
||||
Bcc []string
|
||||
Cc []string
|
||||
Subject string
|
||||
Text string // Plaintext message (optional)
|
||||
HTML string // Html message (optional)
|
||||
Headers textproto.MIMEHeader
|
||||
Attachments []*Attachment
|
||||
ReadReceipt []string
|
||||
}
|
||||
|
||||
// Attachment is a struct representing an email attachment.
|
||||
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
|
||||
type Attachment struct {
|
||||
Filename string
|
||||
Header textproto.MIMEHeader
|
||||
Content []byte
|
||||
}
|
||||
|
||||
func NewEMail(config string) *Email {
|
||||
e := new(Email)
|
||||
e.Headers = textproto.MIMEHeader{}
|
||||
err := json.Unmarshal([]byte(config), e)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if e.From == "" {
|
||||
e.From = e.Username
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// make all send information to byte
|
||||
func (e *Email) Bytes() ([]byte, error) {
|
||||
buff := &bytes.Buffer{}
|
||||
w := multipart.NewWriter(buff)
|
||||
// Set the appropriate headers (overwriting any conflicts)
|
||||
// Leave out Bcc (only included in envelope headers)
|
||||
e.Headers.Set("To", strings.Join(e.To, ","))
|
||||
if e.Cc != nil {
|
||||
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
|
||||
}
|
||||
e.Headers.Set("From", e.From)
|
||||
e.Headers.Set("Subject", e.Subject)
|
||||
if len(e.ReadReceipt) != 0 {
|
||||
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
|
||||
}
|
||||
e.Headers.Set("MIME-Version", "1.0")
|
||||
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
|
||||
|
||||
// Write the envelope headers (including any custom headers)
|
||||
if err := headerToBytes(buff, e.Headers); err != nil {
|
||||
return nil, fmt.Errorf("Failed to render message headers: %s", err)
|
||||
}
|
||||
// Start the multipart/mixed part
|
||||
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
|
||||
header := textproto.MIMEHeader{}
|
||||
// Check to see if there is a Text or HTML field
|
||||
if e.Text != "" || e.HTML != "" {
|
||||
subWriter := multipart.NewWriter(buff)
|
||||
// Create the multipart alternative part
|
||||
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
|
||||
// Write the header
|
||||
if err := headerToBytes(buff, header); err != nil {
|
||||
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
|
||||
}
|
||||
// Create the body sections
|
||||
if e.Text != "" {
|
||||
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
|
||||
header.Set("Content-Transfer-Encoding", "quoted-printable")
|
||||
if _, err := subWriter.CreatePart(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write the text
|
||||
if err := quotePrintEncode(buff, e.Text); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if e.HTML != "" {
|
||||
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
|
||||
header.Set("Content-Transfer-Encoding", "quoted-printable")
|
||||
if _, err := subWriter.CreatePart(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write the text
|
||||
if err := quotePrintEncode(buff, e.HTML); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := subWriter.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Create attachment part, if necessary
|
||||
for _, a := range e.Attachments {
|
||||
ap, err := w.CreatePart(a.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write the base64Wrapped content to the part
|
||||
base64Wrap(ap, a.Content)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buff.Bytes(), nil
|
||||
}
|
||||
|
||||
// Attach file to the send mail
|
||||
func (e *Email) AttachFile(filename string) (a *Attachment, err error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ct := mime.TypeByExtension(filepath.Ext(filename))
|
||||
basename := path.Base(filename)
|
||||
return e.Attach(f, basename, ct)
|
||||
}
|
||||
|
||||
// Attach is used to attach content from an io.Reader to the email.
|
||||
// Required parameters include an io.Reader, the desired filename for the attachment, and the Content-Type
|
||||
func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) {
|
||||
var buffer bytes.Buffer
|
||||
if _, err = io.Copy(&buffer, r); err != nil {
|
||||
return
|
||||
}
|
||||
at := &Attachment{
|
||||
Filename: filename,
|
||||
Header: textproto.MIMEHeader{},
|
||||
Content: buffer.Bytes(),
|
||||
}
|
||||
// Get the Content-Type to be used in the MIMEHeader
|
||||
if c != "" {
|
||||
at.Header.Set("Content-Type", c)
|
||||
} else {
|
||||
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
|
||||
at.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
|
||||
at.Header.Set("Content-Transfer-Encoding", "base64")
|
||||
e.Attachments = append(e.Attachments, at)
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func (e *Email) Send() error {
|
||||
if e.Auth == nil {
|
||||
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
|
||||
}
|
||||
// Merge the To, Cc, and Bcc fields
|
||||
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
|
||||
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
|
||||
// Check to make sure there is at least one recipient and one "From" address
|
||||
if e.From == "" || len(to) == 0 {
|
||||
return errors.New("Must specify at least one From address and one To address")
|
||||
}
|
||||
from, err := mail.ParseAddress(e.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw, err := e.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
|
||||
}
|
||||
|
||||
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
|
||||
func quotePrintEncode(w io.Writer, s string) error {
|
||||
var buf [3]byte
|
||||
mc := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
// We're assuming Unix style text formats as input (LF line break), and
|
||||
// quoted-printble uses CRLF line breaks. (Literal CRs will become
|
||||
// "=0D", but probably shouldn't be there to begin with!)
|
||||
if c == '\n' {
|
||||
io.WriteString(w, "\r\n")
|
||||
mc = 0
|
||||
continue
|
||||
}
|
||||
|
||||
var nextOut []byte
|
||||
if isPrintable(c) {
|
||||
nextOut = append(buf[:0], c)
|
||||
} else {
|
||||
nextOut = buf[:]
|
||||
qpEscape(nextOut, c)
|
||||
}
|
||||
|
||||
// Add a soft line break if the next (encoded) byte would push this line
|
||||
// to or past the limit.
|
||||
if mc+len(nextOut) >= maxLineLength {
|
||||
if _, err := io.WriteString(w, "=\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
mc = 0
|
||||
}
|
||||
|
||||
if _, err := w.Write(nextOut); err != nil {
|
||||
return err
|
||||
}
|
||||
mc += len(nextOut)
|
||||
}
|
||||
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
|
||||
if mc > 0 {
|
||||
io.WriteString(w, "=\r\n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
|
||||
func isPrintable(c byte) bool {
|
||||
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
|
||||
}
|
||||
|
||||
// qpEscape is a helper function for quotePrintEncode which escapes a
|
||||
// non-printable byte. Expects len(dest) == 3.
|
||||
func qpEscape(dest []byte, c byte) {
|
||||
const nums = "0123456789ABCDEF"
|
||||
dest[0] = '='
|
||||
dest[1] = nums[(c&0xf0)>>4]
|
||||
dest[2] = nums[(c & 0xf)]
|
||||
}
|
||||
|
||||
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
|
||||
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
|
||||
for k, v := range t {
|
||||
// Write the header key
|
||||
_, err := fmt.Fprintf(w, "%s:", k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Write each value in the header
|
||||
for _, c := range v {
|
||||
_, err := fmt.Fprintf(w, " %s\r\n", c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// base64Wrap encodeds the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
|
||||
// The output is then written to the specified io.Writer
|
||||
func base64Wrap(w io.Writer, b []byte) {
|
||||
// 57 raw bytes per 76-byte base64 line.
|
||||
const maxRaw = 57
|
||||
// Buffer for each line, including trailing CRLF.
|
||||
var buffer [maxLineLength + len("\r\n")]byte
|
||||
copy(buffer[maxLineLength:], "\r\n")
|
||||
// Process raw chunks until there's no longer enough to fill a line.
|
||||
for len(b) >= maxRaw {
|
||||
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
|
||||
w.Write(buffer[:])
|
||||
b = b[maxRaw:]
|
||||
}
|
||||
// Handle the last chunk of bytes.
|
||||
if len(b) > 0 {
|
||||
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
|
||||
base64.StdEncoding.Encode(out, b)
|
||||
out = append(out, "\r\n"...)
|
||||
w.Write(out)
|
||||
}
|
||||
}
|
27
utils/mail_test.go
Normal file
27
utils/mail_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMail(t *testing.T) {
|
||||
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
|
||||
mail := NewEMail(config)
|
||||
if mail.Username != "astaxie@gmail.com" {
|
||||
t.Fatal("email parse get username error")
|
||||
}
|
||||
if mail.Password != "astaxie" {
|
||||
t.Fatal("email parse get password error")
|
||||
}
|
||||
if mail.Host != "smtp.gmail.com" {
|
||||
t.Fatal("email parse get host error")
|
||||
}
|
||||
if mail.Port != 587 {
|
||||
t.Fatal("email parse get port error")
|
||||
}
|
||||
mail.To = []string{"xiemengjun@gmail.com"}
|
||||
mail.From = "astaxie@gmail.com"
|
||||
mail.Subject = "hi, just from beego!"
|
||||
mail.Text = "Text Body is, of course, supported!"
|
||||
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
|
||||
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
|
||||
mail.Send()
|
||||
}
|
20
utils/rand.go
Normal file
20
utils/rand.go
Normal file
@ -0,0 +1,20 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
)
|
||||
|
||||
// RandomCreateBytes generate random []byte by specify chars.
|
||||
func RandomCreateBytes(n int, alphabets ...byte) []byte {
|
||||
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
var bytes = make([]byte, n)
|
||||
rand.Read(bytes)
|
||||
for i, b := range bytes {
|
||||
if len(alphabets) == 0 {
|
||||
bytes[i] = alphanum[b%byte(len(alphanum))]
|
||||
} else {
|
||||
bytes[i] = alphabets[b%byte(len(alphabets))]
|
||||
}
|
||||
}
|
||||
return bytes
|
||||
}
|
Loading…
Reference in New Issue
Block a user