mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 12:50:55 +00:00
Merge branch 'release/release1.1.0'
This commit is contained in:
commit
92196c602b
@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
|
|||||||
beego is licensed under the Apache Licence, Version 2.0
|
beego is licensed under the Apache Licence, Version 2.0
|
||||||
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
||||||
|
|
||||||
|
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
|
||||||
## Use case
|
[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
|
||||||
|
|
||||||
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
|
|
||||||
- seocms: [seocms](https://github.com/chinakr/seocms)
|
|
||||||
- CMS: [toropress](https://github.com/insionng/toropress)
|
|
8
app.go
8
app.go
@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
|
|||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
|
||||||
|
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
|
||||||
|
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
|
||||||
|
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
|
||||||
|
app.Handlers.AddAutoPrefix(prefix, c)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
// UrlFor creates a url with another registered controller handler with params.
|
// UrlFor creates a url with another registered controller handler with params.
|
||||||
// The endpoint is formed as path.controller.name to defined the controller method which will run.
|
// The endpoint is formed as path.controller.name to defined the controller method which will run.
|
||||||
// The values need key-pair data to assign into controller method.
|
// The values need key-pair data to assign into controller method.
|
||||||
|
128
beego.go
128
beego.go
@ -4,6 +4,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/astaxie/beego/middleware"
|
"github.com/astaxie/beego/middleware"
|
||||||
@ -11,7 +12,77 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// beego web framework version.
|
// beego web framework version.
|
||||||
const VERSION = "1.0.1"
|
const VERSION = "1.1.0"
|
||||||
|
|
||||||
|
type hookfunc func() error //hook function to run
|
||||||
|
var hooks []hookfunc //hook function slice to store the hookfunc
|
||||||
|
|
||||||
|
type groupRouter struct {
|
||||||
|
pattern string
|
||||||
|
controller ControllerInterface
|
||||||
|
mappingMethods string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterGroups which will store routers
|
||||||
|
type GroupRouters []groupRouter
|
||||||
|
|
||||||
|
// Get a new GroupRouters
|
||||||
|
func NewGroupRouters() GroupRouters {
|
||||||
|
return make([]groupRouter, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add Router in the GroupRouters
|
||||||
|
// it is for plugin or module to register router
|
||||||
|
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
|
||||||
|
var newRG groupRouter
|
||||||
|
if len(mappingMethod) > 0 {
|
||||||
|
newRG = groupRouter{
|
||||||
|
pattern,
|
||||||
|
c,
|
||||||
|
mappingMethod[0],
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
newRG = groupRouter{
|
||||||
|
pattern,
|
||||||
|
c,
|
||||||
|
"",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gr = append(gr, newRG)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gr GroupRouters) AddAuto(c ControllerInterface) {
|
||||||
|
newRG := groupRouter{
|
||||||
|
"",
|
||||||
|
c,
|
||||||
|
"",
|
||||||
|
}
|
||||||
|
gr = append(gr, newRG)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddGroupRouter with the prefix
|
||||||
|
// it will register the router in BeeApp
|
||||||
|
// the follow code is write in modules:
|
||||||
|
// GR:=NewGroupRouters()
|
||||||
|
// GR.AddRouter("/login",&UserController,"get:Login")
|
||||||
|
// GR.AddRouter("/logout",&UserController,"get:Logout")
|
||||||
|
// GR.AddRouter("/register",&UserController,"get:Reg")
|
||||||
|
// the follow code is write in app:
|
||||||
|
// import "github.com/beego/modules/auth"
|
||||||
|
// AddRouterGroup("/admin", auth.GR)
|
||||||
|
func AddGroupRouter(prefix string, groups GroupRouters) *App {
|
||||||
|
for _, v := range groups {
|
||||||
|
if v.pattern == "" {
|
||||||
|
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
|
||||||
|
} else if v.mappingMethods != "" {
|
||||||
|
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
|
||||||
|
} else {
|
||||||
|
BeeApp.Router(prefix+v.pattern, v.controller)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return BeeApp
|
||||||
|
}
|
||||||
|
|
||||||
// Router adds a patterned controller handler to BeeApp.
|
// Router adds a patterned controller handler to BeeApp.
|
||||||
// it's an alias method of App.Router.
|
// it's an alias method of App.Router.
|
||||||
@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
|
|||||||
return BeeApp
|
return BeeApp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoPrefix adds controller handler to BeeApp with prefix.
|
||||||
|
// it's same to App.AutoRouterWithPrefix.
|
||||||
|
func AutoPrefix(prefix string, c ControllerInterface) *App {
|
||||||
|
BeeApp.AutoRouterWithPrefix(prefix, c)
|
||||||
|
return BeeApp
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorHandler registers http.HandlerFunc to each http err code string.
|
// ErrorHandler registers http.HandlerFunc to each http err code string.
|
||||||
// usage:
|
// usage:
|
||||||
// beego.ErrorHandler("404",NotFound)
|
// beego.ErrorHandler("404",NotFound)
|
||||||
@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
|
|||||||
return BeeApp
|
return BeeApp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The hookfunc will run in beego.Run()
|
||||||
|
// such as sessionInit, middlerware start, buildtemplate, admin start
|
||||||
|
func AddAPPStartHook(hf hookfunc) {
|
||||||
|
hooks = append(hooks, hf)
|
||||||
|
}
|
||||||
|
|
||||||
// Run beego application.
|
// Run beego application.
|
||||||
// it's alias of App.Run.
|
// it's alias of App.Run.
|
||||||
func Run() {
|
func Run() {
|
||||||
@ -99,18 +183,32 @@ func Run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//init mime
|
// do hooks function
|
||||||
initMime()
|
for _, hk := range hooks {
|
||||||
|
err := hk()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if SessionOn {
|
if SessionOn {
|
||||||
GlobalSessions, _ = session.NewManager(SessionProvider,
|
var err error
|
||||||
SessionName,
|
sessionConfig := AppConfig.String("sessionConfig")
|
||||||
SessionGCMaxLifetime,
|
if sessionConfig == "" {
|
||||||
SessionSavePath,
|
sessionConfig = `{"cookieName":"` + SessionName + `",` +
|
||||||
HttpTLS,
|
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
|
||||||
SessionHashFunc,
|
`"providerConfig":"` + SessionSavePath + `",` +
|
||||||
SessionHashKey,
|
`"secure":` + strconv.FormatBool(HttpTLS) + `,` +
|
||||||
SessionCookieLifeTime)
|
`"sessionIDHashFunc":"` + SessionHashFunc + `",` +
|
||||||
|
`"sessionIDHashKey":"` + SessionHashKey + `",` +
|
||||||
|
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
|
||||||
|
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
|
||||||
|
}
|
||||||
|
GlobalSessions, err = session.NewManager(SessionProvider,
|
||||||
|
sessionConfig)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
go GlobalSessions.GC()
|
go GlobalSessions.GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,7 +221,7 @@ func Run() {
|
|||||||
|
|
||||||
middleware.VERSION = VERSION
|
middleware.VERSION = VERSION
|
||||||
middleware.AppName = AppName
|
middleware.AppName = AppName
|
||||||
middleware.RegisterErrorHander()
|
middleware.RegisterErrorHandler()
|
||||||
|
|
||||||
if EnableAdmin {
|
if EnableAdmin {
|
||||||
go BeeAdminApp.Run()
|
go BeeAdminApp.Run()
|
||||||
@ -131,3 +229,9 @@ func Run() {
|
|||||||
|
|
||||||
BeeApp.Run()
|
BeeApp.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
hooks = make([]hookfunc, 0)
|
||||||
|
//init mime
|
||||||
|
AddAPPStartHook(initMime)
|
||||||
|
}
|
||||||
|
2
cache/README.md
vendored
2
cache/README.md
vendored
@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
|
|||||||
|
|
||||||
## Memcache adapter
|
## Memcache adapter
|
||||||
|
|
||||||
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
||||||
|
|
||||||
Configure like this:
|
Configure like this:
|
||||||
|
|
||||||
|
50
cache/cache_test.go
vendored
50
cache/cache_test.go
vendored
@ -5,7 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_cache(t *testing.T) {
|
func TestCache(t *testing.T) {
|
||||||
bm, err := NewCache("memory", `{"interval":20}`)
|
bm, err := NewCache("memory", `{"interval":20}`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("init err")
|
t.Error("init err")
|
||||||
@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
|
|||||||
t.Error("delete err")
|
t.Error("delete err")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFileCache(t *testing.T) {
|
||||||
|
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("init err")
|
||||||
|
}
|
||||||
|
if err = bm.Put("astaxie", 1, 10); err != nil {
|
||||||
|
t.Error("set Error", err)
|
||||||
|
}
|
||||||
|
if !bm.IsExist("astaxie") {
|
||||||
|
t.Error("check err")
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||||
|
t.Error("get err")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = bm.Incr("astaxie"); err != nil {
|
||||||
|
t.Error("Incr Error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := bm.Get("astaxie"); v.(int) != 2 {
|
||||||
|
t.Error("get err")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = bm.Decr("astaxie"); err != nil {
|
||||||
|
t.Error("Incr Error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||||
|
t.Error("get err")
|
||||||
|
}
|
||||||
|
bm.Delete("astaxie")
|
||||||
|
if bm.IsExist("astaxie") {
|
||||||
|
t.Error("delete err")
|
||||||
|
}
|
||||||
|
//test string
|
||||||
|
if err = bm.Put("astaxie", "author", 10); err != nil {
|
||||||
|
t.Error("set Error", err)
|
||||||
|
}
|
||||||
|
if !bm.IsExist("astaxie") {
|
||||||
|
t.Error("check err")
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := bm.Get("astaxie"); v.(string) != "author" {
|
||||||
|
t.Error("get err")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
10
cache/file.go
vendored
10
cache/file.go
vendored
@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
|
|||||||
var cfg map[string]string
|
var cfg map[string]string
|
||||||
json.Unmarshal([]byte(config), &cfg)
|
json.Unmarshal([]byte(config), &cfg)
|
||||||
//fmt.Println(cfg)
|
//fmt.Println(cfg)
|
||||||
|
//fmt.Println(config)
|
||||||
if _, ok := cfg["CachePath"]; !ok {
|
if _, ok := cfg["CachePath"]; !ok {
|
||||||
cfg["CachePath"] = FileCachePath
|
cfg["CachePath"] = FileCachePath
|
||||||
}
|
}
|
||||||
@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
var to FileCacheItem
|
var to FileCacheItem
|
||||||
Gob_decode([]byte(filedata), &to)
|
Gob_decode(filedata, &to)
|
||||||
if to.Expired < time.Now().Unix() {
|
if to.Expired < time.Now().Unix() {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
|
|||||||
func (this *FileCache) Incr(key string) error {
|
func (this *FileCache) Incr(key string) error {
|
||||||
data := this.Get(key)
|
data := this.Get(key)
|
||||||
var incr int
|
var incr int
|
||||||
fmt.Println(reflect.TypeOf(data).Name())
|
//fmt.Println(reflect.TypeOf(data).Name())
|
||||||
if reflect.TypeOf(data).Name() != "int" {
|
if reflect.TypeOf(data).Name() != "int" {
|
||||||
incr = 0
|
incr = 0
|
||||||
} else {
|
} else {
|
||||||
@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
|
|||||||
// Clean cached files.
|
// Clean cached files.
|
||||||
// not implemented.
|
// not implemented.
|
||||||
func (this *FileCache) ClearAll() error {
|
func (this *FileCache) ClearAll() error {
|
||||||
//this.CachePath .递归删除
|
//this.CachePath
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Gob decodes file cache item.
|
// Gob decodes file cache item.
|
||||||
func Gob_decode(data []byte, to interface{}) error {
|
func Gob_decode(data []byte, to *FileCacheItem) error {
|
||||||
buf := bytes.NewBuffer(data)
|
buf := bytes.NewBuffer(data)
|
||||||
dec := gob.NewDecoder(buf)
|
dec := gob.NewDecoder(buf)
|
||||||
return dec.Decode(&to)
|
return dec.Decode(&to)
|
||||||
|
41
cache/memcache.go
vendored
41
cache/memcache.go
vendored
@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
|
|||||||
// get value from memcache.
|
// get value from memcache.
|
||||||
func (rc *MemcacheCache) Get(key string) interface{} {
|
func (rc *MemcacheCache) Get(key string) interface{} {
|
||||||
if rc.c == nil {
|
if rc.c == nil {
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
v, err := rc.c.Get(key)
|
v, err := rc.c.Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
|
|||||||
// put value to memcache. only support string.
|
// put value to memcache. only support string.
|
||||||
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
||||||
if rc.c == nil {
|
if rc.c == nil {
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
v, ok := val.(string)
|
v, ok := val.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
|||||||
// delete value in memcache.
|
// delete value in memcache.
|
||||||
func (rc *MemcacheCache) Delete(key string) error {
|
func (rc *MemcacheCache) Delete(key string) error {
|
||||||
if rc.c == nil {
|
if rc.c == nil {
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
_, err := rc.c.Delete(key)
|
_, err := rc.c.Delete(key)
|
||||||
return err
|
return err
|
||||||
@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
|
|||||||
// check value exists in memcache.
|
// check value exists in memcache.
|
||||||
func (rc *MemcacheCache) IsExist(key string) bool {
|
func (rc *MemcacheCache) IsExist(key string) bool {
|
||||||
if rc.c == nil {
|
if rc.c == nil {
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
v, err := rc.c.Get(key)
|
v, err := rc.c.Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
|
|||||||
// clear all cached in memcache.
|
// clear all cached in memcache.
|
||||||
func (rc *MemcacheCache) ClearAll() error {
|
func (rc *MemcacheCache) ClearAll() error {
|
||||||
if rc.c == nil {
|
if rc.c == nil {
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err := rc.c.FlushAll()
|
err := rc.c.FlushAll()
|
||||||
return err
|
return err
|
||||||
@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
|
|||||||
return errors.New("config has no conn key")
|
return errors.New("config has no conn key")
|
||||||
}
|
}
|
||||||
rc.conninfo = cf["conn"]
|
rc.conninfo = cf["conn"]
|
||||||
rc.c = rc.connectInit()
|
var err error
|
||||||
if rc.c == nil {
|
rc.c, err = rc.connectInit()
|
||||||
|
if err != nil {
|
||||||
return errors.New("dial tcp conn error")
|
return errors.New("dial tcp conn error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to memcache and keep the connection.
|
// connect to memcache and keep the connection.
|
||||||
func (rc *MemcacheCache) connectInit() *memcache.Connection {
|
func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
|
||||||
c, err := memcache.Connect(rc.conninfo)
|
c, err := memcache.Connect(rc.conninfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return c
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
106
cache/redis.go
vendored
106
cache/redis.go
vendored
@ -3,6 +3,7 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/beego/redigo/redis"
|
"github.com/beego/redigo/redis"
|
||||||
)
|
)
|
||||||
@ -14,7 +15,7 @@ var (
|
|||||||
|
|
||||||
// Redis cache adapter.
|
// Redis cache adapter.
|
||||||
type RedisCache struct {
|
type RedisCache struct {
|
||||||
c redis.Conn
|
p *redis.Pool // redis connection pool
|
||||||
conninfo string
|
conninfo string
|
||||||
key string
|
key string
|
||||||
}
|
}
|
||||||
@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
|
|||||||
return &RedisCache{key: DefaultKey}
|
return &RedisCache{key: DefaultKey}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// actually do the redis cmds
|
||||||
|
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||||
|
c := rc.p.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
return c.Do(commandName, args...)
|
||||||
|
}
|
||||||
|
|
||||||
// Get cache from redis.
|
// Get cache from redis.
|
||||||
func (rc *RedisCache) Get(key string) interface{} {
|
func (rc *RedisCache) Get(key string) interface{} {
|
||||||
if rc.c == nil {
|
v, err := rc.do("HGET", rc.key, key)
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
v, err := rc.c.Do("HGET", rc.key, key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
// put cache to redis.
|
// put cache to redis.
|
||||||
// timeout is ignored.
|
// timeout is ignored.
|
||||||
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
|
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
|
||||||
if rc.c == nil {
|
_, err := rc.do("HSET", rc.key, key, val)
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err := rc.c.Do("HSET", rc.key, key, val)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete cache in redis.
|
// delete cache in redis.
|
||||||
func (rc *RedisCache) Delete(key string) error {
|
func (rc *RedisCache) Delete(key string) error {
|
||||||
if rc.c == nil {
|
_, err := rc.do("HDEL", rc.key, key)
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err := rc.c.Do("HDEL", rc.key, key)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// check cache exist in redis.
|
// check cache exist in redis.
|
||||||
func (rc *RedisCache) IsExist(key string) bool {
|
func (rc *RedisCache) IsExist(key string) bool {
|
||||||
if rc.c == nil {
|
v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
// increase counter in redis.
|
// increase counter in redis.
|
||||||
func (rc *RedisCache) Incr(key string) error {
|
func (rc *RedisCache) Incr(key string) error {
|
||||||
if rc.c == nil {
|
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrease counter in redis.
|
// decrease counter in redis.
|
||||||
func (rc *RedisCache) Decr(key string) error {
|
func (rc *RedisCache) Decr(key string) error {
|
||||||
if rc.c == nil {
|
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clean all cache in redis. delete this redis collection.
|
// clean all cache in redis. delete this redis collection.
|
||||||
func (rc *RedisCache) ClearAll() error {
|
func (rc *RedisCache) ClearAll() error {
|
||||||
if rc.c == nil {
|
_, err := rc.do("DEL", rc.key)
|
||||||
var err error
|
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, err := rc.c.Do("DEL", rc.key)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
|
|||||||
func (rc *RedisCache) StartAndGC(config string) error {
|
func (rc *RedisCache) StartAndGC(config string) error {
|
||||||
var cf map[string]string
|
var cf map[string]string
|
||||||
json.Unmarshal([]byte(config), &cf)
|
json.Unmarshal([]byte(config), &cf)
|
||||||
|
|
||||||
if _, ok := cf["key"]; !ok {
|
if _, ok := cf["key"]; !ok {
|
||||||
cf["key"] = DefaultKey
|
cf["key"] = DefaultKey
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := cf["conn"]; !ok {
|
if _, ok := cf["conn"]; !ok {
|
||||||
return errors.New("config has no conn key")
|
return errors.New("config has no conn key")
|
||||||
}
|
}
|
||||||
|
|
||||||
rc.key = cf["key"]
|
rc.key = cf["key"]
|
||||||
rc.conninfo = cf["conn"]
|
rc.conninfo = cf["conn"]
|
||||||
var err error
|
rc.connectInit()
|
||||||
rc.c, err = rc.connectInit()
|
|
||||||
if err != nil {
|
c := rc.p.Get()
|
||||||
|
defer c.Close()
|
||||||
|
if err := c.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if rc.c == nil {
|
|
||||||
return errors.New("dial tcp conn error")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to redis.
|
// connect to redis.
|
||||||
func (rc *RedisCache) connectInit() (redis.Conn, error) {
|
func (rc *RedisCache) connectInit() {
|
||||||
|
// initialize a new pool
|
||||||
|
rc.p = &redis.Pool{
|
||||||
|
MaxIdle: 3,
|
||||||
|
IdleTimeout: 180 * time.Second,
|
||||||
|
Dial: func() (redis.Conn, error) {
|
||||||
c, err := redis.Dial("tcp", rc.conninfo)
|
c, err := redis.Dial("tcp", rc.conninfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -40,6 +40,7 @@ var (
|
|||||||
SessionHashFunc string // session hash generation func.
|
SessionHashFunc string // session hash generation func.
|
||||||
SessionHashKey string // session hash salt string.
|
SessionHashKey string // session hash salt string.
|
||||||
SessionCookieLifeTime int // the life time of session id in cookie.
|
SessionCookieLifeTime int // the life time of session id in cookie.
|
||||||
|
SessionAutoSetCookie bool // auto setcookie
|
||||||
UseFcgi bool
|
UseFcgi bool
|
||||||
MaxMemory int64
|
MaxMemory int64
|
||||||
EnableGzip bool // flag of enable gzip
|
EnableGzip bool // flag of enable gzip
|
||||||
@ -96,6 +97,7 @@ func init() {
|
|||||||
SessionHashFunc = "sha1"
|
SessionHashFunc = "sha1"
|
||||||
SessionHashKey = "beegoserversessionkey"
|
SessionHashKey = "beegoserversessionkey"
|
||||||
SessionCookieLifeTime = 0 //set cookie default is the brower life
|
SessionCookieLifeTime = 0 //set cookie default is the brower life
|
||||||
|
SessionAutoSetCookie = true
|
||||||
|
|
||||||
UseFcgi = false
|
UseFcgi = false
|
||||||
|
|
||||||
@ -139,6 +141,7 @@ func init() {
|
|||||||
func ParseConfig() (err error) {
|
func ParseConfig() (err error) {
|
||||||
AppConfig, err = config.NewConfig("ini", AppConfigPath)
|
AppConfig, err = config.NewConfig("ini", AppConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
AppConfig = config.NewFakeConfig()
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
HttpAddr = AppConfig.String("HttpAddr")
|
HttpAddr = AppConfig.String("HttpAddr")
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
type ConfigContainer interface {
|
type ConfigContainer interface {
|
||||||
Set(key, val string) error // support section::key type in given key when using ini type.
|
Set(key, val string) error // support section::key type in given key when using ini type.
|
||||||
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
|
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
|
||||||
|
Strings(key string) []string //get string slice
|
||||||
Int(key string) (int, error)
|
Int(key string) (int, error)
|
||||||
Int64(key string) (int64, error)
|
Int64(key string) (int64, error)
|
||||||
Bool(key string) (bool, error)
|
Bool(key string) (bool, error)
|
||||||
|
62
config/fake.go
Normal file
62
config/fake.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeConfigContainer struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) getData(key string) string {
|
||||||
|
key = strings.ToLower(key)
|
||||||
|
return c.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Set(key, val string) error {
|
||||||
|
key = strings.ToLower(key)
|
||||||
|
c.data[key] = val
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) String(key string) string {
|
||||||
|
return c.getData(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Strings(key string) []string {
|
||||||
|
return strings.Split(c.getData(key), ";")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Int(key string) (int, error) {
|
||||||
|
return strconv.Atoi(c.getData(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
|
||||||
|
return strconv.ParseInt(c.getData(key), 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
|
||||||
|
return strconv.ParseBool(c.getData(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) Float(key string) (float64, error) {
|
||||||
|
return strconv.ParseFloat(c.getData(key), 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
|
||||||
|
key = strings.ToLower(key)
|
||||||
|
if v, ok := c.data[key]; ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("key not find")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ConfigContainer = new(fakeConfigContainer)
|
||||||
|
|
||||||
|
func NewFakeConfig() ConfigContainer {
|
||||||
|
return &fakeConfigContainer{
|
||||||
|
data: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
|
|||||||
return c.getdata(key)
|
return c.getdata(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strings returns the []string value for a given key.
|
||||||
|
func (c *IniConfigContainer) Strings(key string) []string {
|
||||||
|
return strings.Split(c.String(key), ";")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteValue writes a new value for key.
|
// WriteValue writes a new value for key.
|
||||||
// if write to one section, the key need be "section::key".
|
// if write to one section, the key need be "section::key".
|
||||||
// if the section is not existed, it panics.
|
// if the section is not existed, it panics.
|
||||||
|
@ -19,6 +19,7 @@ copyrequestbody = true
|
|||||||
key1="asta"
|
key1="asta"
|
||||||
key2 = "xie"
|
key2 = "xie"
|
||||||
CaseInsensitive = true
|
CaseInsensitive = true
|
||||||
|
peers = one;two;three
|
||||||
`
|
`
|
||||||
|
|
||||||
func TestIni(t *testing.T) {
|
func TestIni(t *testing.T) {
|
||||||
@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
|
|||||||
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
|
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
|
||||||
t.Fatal("get demo.caseinsensitive error")
|
t.Fatal("get demo.caseinsensitive error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
|
||||||
|
t.Fatal("get strings error", data)
|
||||||
|
} else if data[0] != "one" {
|
||||||
|
t.Fatal("get first params error not equat to one")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strings returns the []string value for a given key.
|
||||||
|
func (c *JsonConfigContainer) Strings(key string) []string {
|
||||||
|
return strings.Split(c.String(key), ";")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteValue writes a new value for key.
|
// WriteValue writes a new value for key.
|
||||||
func (c *JsonConfigContainer) Set(key, val string) error {
|
func (c *JsonConfigContainer) Set(key, val string) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/beego/x2j"
|
"github.com/beego/x2j"
|
||||||
@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strings returns the []string value for a given key.
|
||||||
|
func (c *XMLConfigContainer) Strings(key string) []string {
|
||||||
|
return strings.Split(c.String(key), ";")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteValue writes a new value for key.
|
// WriteValue writes a new value for key.
|
||||||
func (c *XMLConfigContainer) Set(key, val string) error {
|
func (c *XMLConfigContainer) Set(key, val string) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/beego/goyaml2"
|
"github.com/beego/goyaml2"
|
||||||
@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strings returns the []string value for a given key.
|
||||||
|
func (c *YAMLConfigContainer) Strings(key string) []string {
|
||||||
|
return strings.Split(c.String(key), ";")
|
||||||
|
}
|
||||||
|
|
||||||
// WriteValue writes a new value for key.
|
// WriteValue writes a new value for key.
|
||||||
func (c *YAMLConfigContainer) Set(key, val string) error {
|
func (c *YAMLConfigContainer) Set(key, val string) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
|
@ -3,7 +3,6 @@ package beego
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
@ -22,6 +21,7 @@ import (
|
|||||||
|
|
||||||
"github.com/astaxie/beego/context"
|
"github.com/astaxie/beego/context"
|
||||||
"github.com/astaxie/beego/session"
|
"github.com/astaxie/beego/session"
|
||||||
|
"github.com/astaxie/beego/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -45,6 +45,7 @@ type Controller struct {
|
|||||||
CruSession session.SessionStore
|
CruSession session.SessionStore
|
||||||
XSRFExpire int
|
XSRFExpire int
|
||||||
AppController interface{}
|
AppController interface{}
|
||||||
|
EnableReander bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ControllerInterface is an interface to uniform all controller handler.
|
// ControllerInterface is an interface to uniform all controller handler.
|
||||||
@ -74,6 +75,8 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
|
|||||||
c.Ctx = ctx
|
c.Ctx = ctx
|
||||||
c.TplExt = "tpl"
|
c.TplExt = "tpl"
|
||||||
c.AppController = app
|
c.AppController = app
|
||||||
|
c.EnableReander = true
|
||||||
|
c.Data = ctx.Input.Data
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare runs after Init before request function execution.
|
// Prepare runs after Init before request function execution.
|
||||||
@ -123,6 +126,9 @@ func (c *Controller) Options() {
|
|||||||
|
|
||||||
// Render sends the response with rendered template bytes as text/html type.
|
// Render sends the response with rendered template bytes as text/html type.
|
||||||
func (c *Controller) Render() error {
|
func (c *Controller) Render() error {
|
||||||
|
if !c.EnableReander {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
rb, err := c.RenderBytes()
|
rb, err := c.RenderBytes()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -140,7 +146,7 @@ func (c *Controller) RenderString() (string, error) {
|
|||||||
return string(b), e
|
return string(b), e
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenderBytes returns the bytes of renderd tempate string. Do not send out response.
|
// RenderBytes returns the bytes of rendered template string. Do not send out response.
|
||||||
func (c *Controller) RenderBytes() ([]byte, error) {
|
func (c *Controller) RenderBytes() ([]byte, error) {
|
||||||
//if the controller has set layout, then first get the tplname's content set the content to the layout
|
//if the controller has set layout, then first get the tplname's content set the content to the layout
|
||||||
if c.Layout != "" {
|
if c.Layout != "" {
|
||||||
@ -165,7 +171,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
|||||||
|
|
||||||
if c.LayoutSections != nil {
|
if c.LayoutSections != nil {
|
||||||
for sectionName, sectionTpl := range c.LayoutSections {
|
for sectionName, sectionTpl := range c.LayoutSections {
|
||||||
if (sectionTpl == "") {
|
if sectionTpl == "" {
|
||||||
c.Data[sectionName] = ""
|
c.Data[sectionName] = ""
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -391,12 +397,14 @@ func (c *Controller) DelSession(name interface{}) {
|
|||||||
// SessionRegenerateID regenerates session id for this session.
|
// SessionRegenerateID regenerates session id for this session.
|
||||||
// the session data have no changes.
|
// the session data have no changes.
|
||||||
func (c *Controller) SessionRegenerateID() {
|
func (c *Controller) SessionRegenerateID() {
|
||||||
|
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
|
||||||
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
|
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
|
||||||
c.Ctx.Input.CruSession = c.CruSession
|
c.Ctx.Input.CruSession = c.CruSession
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroySession cleans session data and session cookie.
|
// DestroySession cleans session data and session cookie.
|
||||||
func (c *Controller) DestroySession() {
|
func (c *Controller) DestroySession() {
|
||||||
|
c.Ctx.Input.CruSession.Flush()
|
||||||
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
|
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -454,7 +462,7 @@ func (c *Controller) XsrfToken() string {
|
|||||||
} else {
|
} else {
|
||||||
expire = int64(XSRFExpire)
|
expire = int64(XSRFExpire)
|
||||||
}
|
}
|
||||||
token = getRandomString(15)
|
token = string(utils.RandomCreateBytes(15))
|
||||||
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
|
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
|
||||||
}
|
}
|
||||||
c._xsrf_token = token
|
c._xsrf_token = token
|
||||||
@ -491,14 +499,3 @@ func (c *Controller) XsrfFormHtml() string {
|
|||||||
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
|
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
|
||||||
return c.controllerName, c.actionName
|
return c.controllerName, c.actionName
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRandomString returns random string.
|
|
||||||
func getRandomString(n int) string {
|
|
||||||
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
|
||||||
var bytes = make([]byte, n)
|
|
||||||
rand.Read(bytes)
|
|
||||||
for i, b := range bytes {
|
|
||||||
bytes[i] = alphanum[b%byte(len(alphanum))]
|
|
||||||
}
|
|
||||||
return string(bytes)
|
|
||||||
}
|
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/astaxie/beego"
|
|
||||||
"github.com/garyburd/go-websocket/websocket"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
17
filter.go
17
filter.go
@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
|||||||
if router == mr.pattern {
|
if router == mr.pattern {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
//pattern /admin router /admin/ match
|
||||||
|
//pattern /admin/ router /admin don't match, because url will 301 in router
|
||||||
|
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
if mr.hasregex {
|
if mr.hasregex {
|
||||||
if !mr.regex.MatchString(router) {
|
if !mr.regex.MatchString(router) {
|
||||||
return false, nil
|
return false, nil
|
||||||
@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
|
||||||
mr := new(FilterRouter)
|
mr := new(FilterRouter)
|
||||||
mr.params = make(map[int]string)
|
mr.params = make(map[int]string)
|
||||||
mr.filterFunc = filter
|
mr.filterFunc = filter
|
||||||
@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
|||||||
j := 0
|
j := 0
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
if strings.HasPrefix(part, ":") {
|
if strings.HasPrefix(part, ":") {
|
||||||
expr := "(.+)"
|
expr := "(.*)"
|
||||||
//a user may choose to override the default expression
|
//a user may choose to override the default expression
|
||||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||||
if index := strings.Index(part, "("); index != -1 {
|
if index := strings.Index(part, "("); index != -1 {
|
||||||
@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
|||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(part, "*") {
|
if strings.HasPrefix(part, "*") {
|
||||||
expr := "(.+)"
|
expr := "(.*)"
|
||||||
if part == "*.*" {
|
if part == "*.*" {
|
||||||
mr.params[j] = ":path"
|
mr.params[j] = ":path"
|
||||||
parts[i] = "([^.]+).([^.]+)"
|
parts[i] = "([^.]+).([^.]+)"
|
||||||
@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
|||||||
pattern = strings.Join(parts, "/")
|
pattern = strings.Join(parts, "/")
|
||||||
regex, regexErr := regexp.Compile(pattern)
|
regex, regexErr := regexp.Compile(pattern)
|
||||||
if regexErr != nil {
|
if regexErr != nil {
|
||||||
//TODO add error handling here to avoid panic
|
return nil, regexErr
|
||||||
panic(regexErr)
|
|
||||||
}
|
}
|
||||||
mr.regex = regex
|
mr.regex = regex
|
||||||
mr.hasregex = true
|
mr.hasregex = true
|
||||||
}
|
}
|
||||||
mr.pattern = pattern
|
mr.pattern = pattern
|
||||||
return mr
|
return mr, nil
|
||||||
}
|
}
|
||||||
|
@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
|
|||||||
t.Errorf("user define func can't run")
|
t.Errorf("user define func can't run")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var FilterAdminUser = func(ctx *context.Context) {
|
||||||
|
ctx.Output.Body([]byte("i am admin"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter pattern /admin/:all
|
||||||
|
// all url like /admin/ /admin/xie will all get filter
|
||||||
|
|
||||||
|
func TestPatternTwo(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "/admin/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := NewControllerRegistor()
|
||||||
|
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != "i am admin" {
|
||||||
|
t.Errorf("filter /admin/ can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatternThree(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := NewControllerRegistor()
|
||||||
|
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != "i am admin" {
|
||||||
|
t.Errorf("filter /admin/astaxie can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConnWriter implements LoggerInterface.
|
||||||
|
// it writes messages in keep-live tcp connection.
|
||||||
type ConnWriter struct {
|
type ConnWriter struct {
|
||||||
lg *log.Logger
|
lg *log.Logger
|
||||||
innerWriter io.WriteCloser
|
innerWriter io.WriteCloser
|
||||||
@ -17,12 +19,15 @@ type ConnWriter struct {
|
|||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new ConnWrite returning as LoggerInterface.
|
||||||
func NewConn() LoggerInterface {
|
func NewConn() LoggerInterface {
|
||||||
conn := new(ConnWriter)
|
conn := new(ConnWriter)
|
||||||
conn.Level = LevelTrace
|
conn.Level = LevelTrace
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init connection writer with json config.
|
||||||
|
// json config only need key "level".
|
||||||
func (c *ConnWriter) Init(jsonconfig string) error {
|
func (c *ConnWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write message in connection.
|
||||||
|
// if connection is down, try to re-connect.
|
||||||
func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
||||||
if level < c.Level {
|
if level < c.Level {
|
||||||
return nil
|
return nil
|
||||||
@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing method. empty.
|
||||||
func (c *ConnWriter) Flush() {
|
func (c *ConnWriter) Flush() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// destroy connection writer and close tcp listener.
|
||||||
func (c *ConnWriter) Destroy() {
|
func (c *ConnWriter) Destroy() {
|
||||||
if c.innerWriter == nil {
|
if c.innerWriter == nil {
|
||||||
return
|
return
|
||||||
|
@ -4,13 +4,35 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Brush func(string) string
|
||||||
|
|
||||||
|
func NewBrush(color string) Brush {
|
||||||
|
pre := "\033["
|
||||||
|
reset := "\033[0m"
|
||||||
|
return func(text string) string {
|
||||||
|
return pre + color + "m" + text + reset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var colors = []Brush{
|
||||||
|
NewBrush("1;36"), // Trace cyan
|
||||||
|
NewBrush("1;34"), // Debug blue
|
||||||
|
NewBrush("1;32"), // Info green
|
||||||
|
NewBrush("1;33"), // Warn yellow
|
||||||
|
NewBrush("1;31"), // Error red
|
||||||
|
NewBrush("1;35"), // Critical purple
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
|
||||||
type ConsoleWriter struct {
|
type ConsoleWriter struct {
|
||||||
lg *log.Logger
|
lg *log.Logger
|
||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create ConsoleWriter returning as LoggerInterface.
|
||||||
func NewConsole() LoggerInterface {
|
func NewConsole() LoggerInterface {
|
||||||
cw := new(ConsoleWriter)
|
cw := new(ConsoleWriter)
|
||||||
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
|
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
|
||||||
@ -18,6 +40,8 @@ func NewConsole() LoggerInterface {
|
|||||||
return cw
|
return cw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init console logger.
|
||||||
|
// jsonconfig like '{"level":LevelTrace}'.
|
||||||
func (c *ConsoleWriter) Init(jsonconfig string) error {
|
func (c *ConsoleWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -26,18 +50,25 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write message in console.
|
||||||
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
|
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
|
||||||
if level < c.Level {
|
if level < c.Level {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if goos := runtime.GOOS; goos == "windows" {
|
||||||
c.lg.Println(msg)
|
c.lg.Println(msg)
|
||||||
|
} else {
|
||||||
|
c.lg.Println(colors[level](msg))
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing method. empty.
|
||||||
func (c *ConsoleWriter) Destroy() {
|
func (c *ConsoleWriter) Destroy() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing method. empty.
|
||||||
func (c *ConsoleWriter) Flush() {
|
func (c *ConsoleWriter) Flush() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
17
logs/file.go
17
logs/file.go
@ -13,6 +13,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FileLogWriter implements LoggerInterface.
|
||||||
|
// It writes messages by lines limit, file size limit, or time frequency.
|
||||||
type FileLogWriter struct {
|
type FileLogWriter struct {
|
||||||
*log.Logger
|
*log.Logger
|
||||||
mw *MuxWriter
|
mw *MuxWriter
|
||||||
@ -38,17 +40,20 @@ type FileLogWriter struct {
|
|||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// an *os.File writer with locker.
|
||||||
type MuxWriter struct {
|
type MuxWriter struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
fd *os.File
|
fd *os.File
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write to os.File.
|
||||||
func (l *MuxWriter) Write(b []byte) (int, error) {
|
func (l *MuxWriter) Write(b []byte) (int, error) {
|
||||||
l.Lock()
|
l.Lock()
|
||||||
defer l.Unlock()
|
defer l.Unlock()
|
||||||
return l.fd.Write(b)
|
return l.fd.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set os.File in writer.
|
||||||
func (l *MuxWriter) SetFd(fd *os.File) {
|
func (l *MuxWriter) SetFd(fd *os.File) {
|
||||||
if l.fd != nil {
|
if l.fd != nil {
|
||||||
l.fd.Close()
|
l.fd.Close()
|
||||||
@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
|
|||||||
l.fd = fd
|
l.fd = fd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a FileLogWriter returning as LoggerInterface.
|
||||||
func NewFileWriter() LoggerInterface {
|
func NewFileWriter() LoggerInterface {
|
||||||
w := &FileLogWriter{
|
w := &FileLogWriter{
|
||||||
Filename: "",
|
Filename: "",
|
||||||
@ -73,7 +79,8 @@ func NewFileWriter() LoggerInterface {
|
|||||||
return w
|
return w
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonconfig like this
|
// Init file logger with json config.
|
||||||
|
// jsonconfig like:
|
||||||
// {
|
// {
|
||||||
// "filename":"logs/beego.log",
|
// "filename":"logs/beego.log",
|
||||||
// "maxlines":10000,
|
// "maxlines":10000,
|
||||||
@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// start file logger. create log file and set to locker-inside file writer.
|
||||||
func (w *FileLogWriter) StartLogger() error {
|
func (w *FileLogWriter) StartLogger() error {
|
||||||
fd, err := w.createLogFile()
|
fd, err := w.createLogFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
|
|||||||
w.maxsize_cursize += size
|
w.maxsize_cursize += size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write logger message into file.
|
||||||
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
|
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
|
||||||
if level < w.Level {
|
if level < w.Level {
|
||||||
return nil
|
return nil
|
||||||
@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DoRotate means it need to write file in new file.
|
||||||
|
// new file name like xx.log.2013-01-01.2
|
||||||
func (w *FileLogWriter) DoRotate() error {
|
func (w *FileLogWriter) DoRotate() error {
|
||||||
_, err := os.Lstat(w.Filename)
|
_, err := os.Lstat(w.Filename)
|
||||||
if err == nil { // file exists
|
if err == nil { // file exists
|
||||||
@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// destroy file logger, close file writer.
|
||||||
func (w *FileLogWriter) Destroy() {
|
func (w *FileLogWriter) Destroy() {
|
||||||
w.mw.fd.Close()
|
w.mw.fd.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// flush file logger.
|
||||||
|
// there are no buffering messages in file logger in memory.
|
||||||
|
// flush file means sync file from disk.
|
||||||
func (w *FileLogWriter) Flush() {
|
func (w *FileLogWriter) Flush() {
|
||||||
w.mw.fd.Sync()
|
w.mw.fd.Sync()
|
||||||
}
|
}
|
||||||
|
24
logs/log.go
24
logs/log.go
@ -6,6 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// log message levels
|
||||||
LevelTrace = iota
|
LevelTrace = iota
|
||||||
LevelDebug
|
LevelDebug
|
||||||
LevelInfo
|
LevelInfo
|
||||||
@ -16,6 +17,7 @@ const (
|
|||||||
|
|
||||||
type loggerType func() LoggerInterface
|
type loggerType func() LoggerInterface
|
||||||
|
|
||||||
|
// LoggerInterface defines the behavior of a log provider.
|
||||||
type LoggerInterface interface {
|
type LoggerInterface interface {
|
||||||
Init(config string) error
|
Init(config string) error
|
||||||
WriteMsg(msg string, level int) error
|
WriteMsg(msg string, level int) error
|
||||||
@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
|
|||||||
adapters[name] = log
|
adapters[name] = log
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeeLogger is default logger in beego application.
|
||||||
|
// it can contain several providers and log message into all providers.
|
||||||
type BeeLogger struct {
|
type BeeLogger struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
level int
|
level int
|
||||||
@ -50,7 +54,9 @@ type logMsg struct {
|
|||||||
msg string
|
msg string
|
||||||
}
|
}
|
||||||
|
|
||||||
// config need to be correct JSON as string: {"interval":360}
|
// NewLogger returns a new BeeLogger.
|
||||||
|
// channellen means the number of messages in chan.
|
||||||
|
// if the buffering chan is full, logger adapters write to file or other way.
|
||||||
func NewLogger(channellen int64) *BeeLogger {
|
func NewLogger(channellen int64) *BeeLogger {
|
||||||
bl := new(BeeLogger)
|
bl := new(BeeLogger)
|
||||||
bl.msg = make(chan *logMsg, channellen)
|
bl.msg = make(chan *logMsg, channellen)
|
||||||
@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
|
|||||||
return bl
|
return bl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogger provides a given logger adapter into BeeLogger with config string.
|
||||||
|
// config need to be correct JSON as string: {"interval":360}.
|
||||||
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
||||||
bl.lock.Lock()
|
bl.lock.Lock()
|
||||||
defer bl.lock.Unlock()
|
defer bl.lock.Unlock()
|
||||||
@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove a logger adapter in BeeLogger.
|
||||||
func (bl *BeeLogger) DelLogger(adaptername string) error {
|
func (bl *BeeLogger) DelLogger(adaptername string) error {
|
||||||
bl.lock.Lock()
|
bl.lock.Lock()
|
||||||
defer bl.lock.Unlock()
|
defer bl.lock.Unlock()
|
||||||
@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set log message level.
|
||||||
|
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
|
||||||
func (bl *BeeLogger) SetLevel(l int) {
|
func (bl *BeeLogger) SetLevel(l int) {
|
||||||
bl.level = l
|
bl.level = l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// start logger chan reading.
|
||||||
|
// when chan is full, write logs.
|
||||||
func (bl *BeeLogger) StartLogger() {
|
func (bl *BeeLogger) StartLogger() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log trace level message.
|
||||||
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
|
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[T] "+format, v...)
|
msg := fmt.Sprintf("[T] "+format, v...)
|
||||||
bl.writerMsg(LevelTrace, msg)
|
bl.writerMsg(LevelTrace, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log debug level message.
|
||||||
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
|
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[D] "+format, v...)
|
msg := fmt.Sprintf("[D] "+format, v...)
|
||||||
bl.writerMsg(LevelDebug, msg)
|
bl.writerMsg(LevelDebug, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log info level message.
|
||||||
func (bl *BeeLogger) Info(format string, v ...interface{}) {
|
func (bl *BeeLogger) Info(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[I] "+format, v...)
|
msg := fmt.Sprintf("[I] "+format, v...)
|
||||||
bl.writerMsg(LevelInfo, msg)
|
bl.writerMsg(LevelInfo, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log warn level message.
|
||||||
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
|
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[W] "+format, v...)
|
msg := fmt.Sprintf("[W] "+format, v...)
|
||||||
bl.writerMsg(LevelWarn, msg)
|
bl.writerMsg(LevelWarn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log error level message.
|
||||||
func (bl *BeeLogger) Error(format string, v ...interface{}) {
|
func (bl *BeeLogger) Error(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[E] "+format, v...)
|
msg := fmt.Sprintf("[E] "+format, v...)
|
||||||
bl.writerMsg(LevelError, msg)
|
bl.writerMsg(LevelError, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log critical level message.
|
||||||
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
|
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
|
||||||
msg := fmt.Sprintf("[C] "+format, v...)
|
msg := fmt.Sprintf("[C] "+format, v...)
|
||||||
bl.writerMsg(LevelCritical, msg)
|
bl.writerMsg(LevelCritical, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
//flush all chan data
|
// flush all chan data.
|
||||||
func (bl *BeeLogger) Flush() {
|
func (bl *BeeLogger) Flush() {
|
||||||
for _, l := range bl.outputs {
|
for _, l := range bl.outputs {
|
||||||
l.Flush()
|
l.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// close logger, flush all chan data and destroy all adapters in BeeLogger.
|
||||||
func (bl *BeeLogger) Close() {
|
func (bl *BeeLogger) Close() {
|
||||||
for {
|
for {
|
||||||
if len(bl.msg) > 0 {
|
if len(bl.msg) > 0 {
|
||||||
|
18
logs/smtp.go
18
logs/smtp.go
@ -12,7 +12,7 @@ const (
|
|||||||
subjectPhrase = "Diagnostic message from server"
|
subjectPhrase = "Diagnostic message from server"
|
||||||
)
|
)
|
||||||
|
|
||||||
// smtpWriter is used to send emails via given SMTP-server.
|
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
|
||||||
type SmtpWriter struct {
|
type SmtpWriter struct {
|
||||||
Username string `json:"Username"`
|
Username string `json:"Username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
@ -22,10 +22,21 @@ type SmtpWriter struct {
|
|||||||
Level int `json:"level"`
|
Level int `json:"level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create smtp writer.
|
||||||
func NewSmtpWriter() LoggerInterface {
|
func NewSmtpWriter() LoggerInterface {
|
||||||
return &SmtpWriter{Level: LevelTrace}
|
return &SmtpWriter{Level: LevelTrace}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init smtp writer with json config.
|
||||||
|
// config like:
|
||||||
|
// {
|
||||||
|
// "Username":"example@gmail.com",
|
||||||
|
// "password:"password",
|
||||||
|
// "host":"smtp.gmail.com:465",
|
||||||
|
// "subject":"email title",
|
||||||
|
// "sendTos":["email1","email2"],
|
||||||
|
// "level":LevelError
|
||||||
|
// }
|
||||||
func (s *SmtpWriter) Init(jsonconfig string) error {
|
func (s *SmtpWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
err := json.Unmarshal([]byte(jsonconfig), s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write message in smtp writer.
|
||||||
|
// it will send an email with subject and only this message.
|
||||||
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
||||||
if level < s.Level {
|
if level < s.Level {
|
||||||
return nil
|
return nil
|
||||||
@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing method. empty.
|
||||||
func (s *SmtpWriter) Flush() {
|
func (s *SmtpWriter) Flush() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// implementing method. empty.
|
||||||
func (s *SmtpWriter) Destroy() {
|
func (s *SmtpWriter) Destroy() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,17 @@ import (
|
|||||||
"compress/flate"
|
"compress/flate"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"errors"
|
"errors"
|
||||||
//"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
|
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
|
||||||
|
var lock sync.RWMutex
|
||||||
|
|
||||||
// OpenMemZipFile returns MemFile object with a compressed static file.
|
// OpenMemZipFile returns MemFile object with a compressed static file.
|
||||||
// it's used for serve static file if gzip enable.
|
// it's used for serve static file if gzip enable.
|
||||||
@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
|||||||
|
|
||||||
modtime := osfileinfo.ModTime()
|
modtime := osfileinfo.ModTime()
|
||||||
fileSize := osfileinfo.Size()
|
fileSize := osfileinfo.Size()
|
||||||
|
lock.RLock()
|
||||||
cfi, ok := gmfim[zip+":"+path]
|
cfi, ok := gmfim[zip+":"+path]
|
||||||
|
lock.RUnlock()
|
||||||
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
|
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
|
||||||
//fmt.Printf("read %s file %s from cache\n", zip, path)
|
|
||||||
} else {
|
} else {
|
||||||
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
|
|
||||||
var content []byte
|
var content []byte
|
||||||
if zip == "gzip" {
|
if zip == "gzip" {
|
||||||
//将文件内容压缩到zipbuf中
|
//将文件内容压缩到zipbuf中
|
||||||
@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
|
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
gmfim[zip+":"+path] = cfi
|
gmfim[zip+":"+path] = cfi
|
||||||
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
|
|
||||||
}
|
}
|
||||||
return &MemFile{fi: cfi, offset: 0}, nil
|
return &MemFile{fi: cfi, offset: 0}, nil
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,7 @@ var tpl = `
|
|||||||
</html>
|
</html>
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// render default application error page with error and stack string.
|
||||||
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
|
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(tpl)
|
t, _ := template.New("beegoerrortemp").Parse(tpl)
|
||||||
data := make(map[string]string)
|
data := make(map[string]string)
|
||||||
@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
|
|||||||
data["Stack"] = Stack
|
data["Stack"] = Stack
|
||||||
data["BeegoVersion"] = VERSION
|
data["BeegoVersion"] = VERSION
|
||||||
data["GoVersion"] = runtime.Version()
|
data["GoVersion"] = runtime.Version()
|
||||||
|
rw.WriteHeader(500)
|
||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,18 +176,19 @@ var errtpl = `
|
|||||||
</html>
|
</html>
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// map of http handlers for each error string.
|
||||||
var ErrorMaps map[string]http.HandlerFunc
|
var ErrorMaps map[string]http.HandlerFunc
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ErrorMaps = make(map[string]http.HandlerFunc)
|
ErrorMaps = make(map[string]http.HandlerFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
//404
|
// show 404 notfound error.
|
||||||
func NotFound(rw http.ResponseWriter, r *http.Request) {
|
func NotFound(rw http.ResponseWriter, r *http.Request) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
data["Title"] = "Page Not Found"
|
data["Title"] = "Page Not Found"
|
||||||
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." +
|
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
|
||||||
"<br>Perhaps you are here because:" +
|
"<br>Perhaps you are here because:" +
|
||||||
"<br><br><ul>" +
|
"<br><br><ul>" +
|
||||||
"<br>The page has moved" +
|
"<br>The page has moved" +
|
||||||
@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
|
|||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
//401
|
// show 401 unauthorized error.
|
||||||
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
|
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
data["Title"] = "Unauthorized"
|
data["Title"] = "Unauthorized"
|
||||||
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." +
|
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
|
||||||
"<br>Perhaps you are here because:" +
|
"<br>Perhaps you are here because:" +
|
||||||
"<br><br><ul>" +
|
"<br><br><ul>" +
|
||||||
"<br>Check the credentials that you supplied" +
|
"<br>The credentials you supplied are incorrect" +
|
||||||
"<br>Check the address for errors" +
|
"<br>There are errors in the website address" +
|
||||||
"</ul>")
|
"</ul>")
|
||||||
data["BeegoVersion"] = VERSION
|
data["BeegoVersion"] = VERSION
|
||||||
//rw.WriteHeader(http.StatusUnauthorized)
|
//rw.WriteHeader(http.StatusUnauthorized)
|
||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
//403
|
// show 403 forbidden error.
|
||||||
func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
data["Title"] = "Forbidden"
|
data["Title"] = "Forbidden"
|
||||||
data["Content"] = template.HTML("<br>The Page You have requested forbidden." +
|
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
|
||||||
"<br>Perhaps you are here because:" +
|
"<br>Perhaps you are here because:" +
|
||||||
"<br><br><ul>" +
|
"<br><br><ul>" +
|
||||||
"<br>Your address may be blocked" +
|
"<br>Your address may be blocked" +
|
||||||
@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
|||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
//503
|
// show 503 service unavailable error.
|
||||||
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
data["Title"] = "Service Unavailable"
|
data["Title"] = "Service Unavailable"
|
||||||
data["Content"] = template.HTML("<br>The Page You have requested unavailable." +
|
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
|
||||||
"<br>Perhaps you are here because:" +
|
"<br>Perhaps you are here because:" +
|
||||||
"<br><br><ul>" +
|
"<br><br><ul>" +
|
||||||
"<br><br>The page is overloaded" +
|
"<br><br>The page is overloaded" +
|
||||||
@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
|||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
//500
|
// show 500 internal server error.
|
||||||
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
|
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
|
||||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||||
data := make(map[string]interface{})
|
data := make(map[string]interface{})
|
||||||
data["Title"] = "Internal Server Error"
|
data["Title"] = "Internal Server Error"
|
||||||
data["Content"] = template.HTML("<br>The Page You have requested has down now." +
|
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
|
||||||
"<br><br><ul>" +
|
"<br><br><ul>" +
|
||||||
"<br>simply try again later" +
|
"<br>Please try again later and report the error to the website administrator" +
|
||||||
"<br>you should report the fault to the website administrator" +
|
"<br></ul>")
|
||||||
"</ul>")
|
|
||||||
data["BeegoVersion"] = VERSION
|
data["BeegoVersion"] = VERSION
|
||||||
//rw.WriteHeader(http.StatusInternalServerError)
|
//rw.WriteHeader(http.StatusInternalServerError)
|
||||||
t.Execute(rw, data)
|
t.Execute(rw, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show 500 internal error with simple text string.
|
||||||
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
|
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add http handler for given error string.
|
||||||
func Errorhandler(err string, h http.HandlerFunc) {
|
func Errorhandler(err string, h http.HandlerFunc) {
|
||||||
ErrorMaps[err] = h
|
ErrorMaps[err] = h
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterErrorHander() {
|
// register default error http handlers, 404,401,403,500 and 503.
|
||||||
|
func RegisterErrorHandler() {
|
||||||
if _, ok := ErrorMaps["404"]; !ok {
|
if _, ok := ErrorMaps["404"]; !ok {
|
||||||
ErrorMaps["404"] = NotFound
|
ErrorMaps["404"] = NotFound
|
||||||
}
|
}
|
||||||
@ -292,6 +297,8 @@ func RegisterErrorHander() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show error string as simple text message.
|
||||||
|
// if error string is empty, show 500 error as default.
|
||||||
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
|
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
|
||||||
if h, ok := ErrorMaps[errcode]; ok {
|
if h, ok := ErrorMaps[errcode]; ok {
|
||||||
isint, err := strconv.Atoi(errcode)
|
isint, err := strconv.Atoi(errcode)
|
||||||
|
@ -2,16 +2,19 @@ package middleware
|
|||||||
|
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
|
// http exceptions
|
||||||
type HTTPException struct {
|
type HTTPException struct {
|
||||||
StatusCode int // http status code 4xx, 5xx
|
StatusCode int // http status code 4xx, 5xx
|
||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return http exception error string, e.g. "400 Bad Request".
|
||||||
func (e *HTTPException) Error() string {
|
func (e *HTTPException) Error() string {
|
||||||
// return `status description`, e.g. `400 Bad Request`
|
|
||||||
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
|
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// map of http exceptions for each http status code int.
|
||||||
|
// defined 400,401,403,404,405,500,502,503 and 504 default.
|
||||||
var HTTPExceptionMaps map[int]HTTPException
|
var HTTPExceptionMaps map[int]HTTPException
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
3
mime.go
3
mime.go
@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
|
|||||||
".mustache": "text/html",
|
".mustache": "text/html",
|
||||||
}
|
}
|
||||||
|
|
||||||
func initMime() {
|
func initMime() error {
|
||||||
for k, v := range mimemaps {
|
for k, v := range mimemaps {
|
||||||
mime.AddExtensionType(k, v)
|
mime.AddExtensionType(k, v)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
12
orm/cmd.go
12
orm/cmd.go
@ -16,6 +16,7 @@ var (
|
|||||||
commands = make(map[string]commander)
|
commands = make(map[string]commander)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// print help.
|
||||||
func printHelp(errs ...string) {
|
func printHelp(errs ...string) {
|
||||||
content := `orm command usage:
|
content := `orm command usage:
|
||||||
|
|
||||||
@ -31,6 +32,7 @@ func printHelp(errs ...string) {
|
|||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// listen for orm command and then run it if command arguments passed.
|
||||||
func RunCommand() {
|
func RunCommand() {
|
||||||
if len(os.Args) < 2 || os.Args[1] != "orm" {
|
if len(os.Args) < 2 || os.Args[1] != "orm" {
|
||||||
return
|
return
|
||||||
@ -58,6 +60,7 @@ func RunCommand() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sync database struct command interface.
|
||||||
type commandSyncDb struct {
|
type commandSyncDb struct {
|
||||||
al *alias
|
al *alias
|
||||||
force bool
|
force bool
|
||||||
@ -66,6 +69,7 @@ type commandSyncDb struct {
|
|||||||
rtOnError bool
|
rtOnError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse orm command line arguments.
|
||||||
func (d *commandSyncDb) Parse(args []string) {
|
func (d *commandSyncDb) Parse(args []string) {
|
||||||
var name string
|
var name string
|
||||||
|
|
||||||
@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
|
|||||||
d.al = getDbAlias(name)
|
d.al = getDbAlias(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run orm line command.
|
||||||
func (d *commandSyncDb) Run() error {
|
func (d *commandSyncDb) Run() error {
|
||||||
var drops []string
|
var drops []string
|
||||||
if d.force {
|
if d.force {
|
||||||
@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// database creation commander interface implement.
|
||||||
type commandSqlAll struct {
|
type commandSqlAll struct {
|
||||||
al *alias
|
al *alias
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse orm command line arguments.
|
||||||
func (d *commandSqlAll) Parse(args []string) {
|
func (d *commandSqlAll) Parse(args []string) {
|
||||||
var name string
|
var name string
|
||||||
|
|
||||||
@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
|
|||||||
d.al = getDbAlias(name)
|
d.al = getDbAlias(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run orm line command.
|
||||||
func (d *commandSqlAll) Run() error {
|
func (d *commandSqlAll) Run() error {
|
||||||
sqls, indexes := getDbCreateSql(d.al)
|
sqls, indexes := getDbCreateSql(d.al)
|
||||||
var all []string
|
var all []string
|
||||||
@ -243,6 +251,10 @@ func init() {
|
|||||||
commands["sqlall"] = new(commandSqlAll)
|
commands["sqlall"] = new(commandSqlAll)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run syncdb command line.
|
||||||
|
// name means table's alias name. default is "default".
|
||||||
|
// force means run next sql if the current is error.
|
||||||
|
// verbose means show all info when running command or not.
|
||||||
func RunSyncdb(name string, force bool, verbose bool) error {
|
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||||
BootStrap()
|
BootStrap()
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ type dbIndex struct {
|
|||||||
Sql string
|
Sql string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create database drop sql.
|
||||||
func getDbDropSql(al *alias) (sqls []string) {
|
func getDbDropSql(al *alias) (sqls []string) {
|
||||||
if len(modelCache.cache) == 0 {
|
if len(modelCache.cache) == 0 {
|
||||||
fmt.Println("no Model found, need register your model")
|
fmt.Println("no Model found, need register your model")
|
||||||
@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
|
|||||||
return sqls
|
return sqls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get database column type string.
|
||||||
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||||
T := al.DbBaser.DbTypes()
|
T := al.DbBaser.DbTypes()
|
||||||
fieldType := fi.fieldType
|
fieldType := fi.fieldType
|
||||||
@ -79,6 +81,7 @@ checkColumn:
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create alter sql string.
|
||||||
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||||
Q := al.DbBaser.TableQuote()
|
Q := al.DbBaser.TableQuote()
|
||||||
typ := getColumnTyp(al, fi)
|
typ := getColumnTyp(al, fi)
|
||||||
@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
|||||||
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create database creation string.
|
||||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||||
if len(modelCache.cache) == 0 {
|
if len(modelCache.cache) == 0 {
|
||||||
fmt.Println("no Model found, need register your model")
|
fmt.Println("no Model found, need register your model")
|
||||||
|
175
orm/db.go
175
orm/db.go
@ -15,7 +15,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissPK = errors.New("missed pk value")
|
ErrMissPK = errors.New("missed pk value") // missing pk error
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -45,13 +45,22 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// an instance of dbBaser interface/
|
||||||
type dbBase struct {
|
type dbBase struct {
|
||||||
ins dbBaser
|
ins dbBaser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check dbBase implements dbBaser interface.
|
||||||
var _ dbBaser = new(dbBase)
|
var _ dbBaser = new(dbBase)
|
||||||
|
|
||||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
// get struct columns values as interface slice.
|
||||||
|
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
if names != nil {
|
||||||
|
columns = *names
|
||||||
|
}
|
||||||
|
|
||||||
for _, column := range cols {
|
for _, column := range cols {
|
||||||
var fi *fieldInfo
|
var fi *fieldInfo
|
||||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||||
@ -64,14 +73,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
|||||||
}
|
}
|
||||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if names != nil {
|
||||||
columns = append(columns, column)
|
columns = append(columns, column)
|
||||||
|
}
|
||||||
|
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if names != nil {
|
||||||
|
*names = columns
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get one field value in struct column as interface.
|
||||||
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
|
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
|
||||||
var value interface{}
|
var value interface{}
|
||||||
if fi.pk {
|
if fi.pk {
|
||||||
@ -140,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create insert sql preparation statement object.
|
||||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
@ -165,8 +185,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
return stmt, query, err
|
return stmt, query, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insert struct with prepared statement and given struct reflect value.
|
||||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -185,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query sql ,read records and persist in dbBaser.
|
||||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||||
var whereCols []string
|
var whereCols []string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
@ -192,7 +214,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
// if specify cols length > 0, then use it for where condition.
|
// if specify cols length > 0, then use it for where condition.
|
||||||
if len(cols) > 0 {
|
if len(cols) > 0 {
|
||||||
var err error
|
var err error
|
||||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
whereCols = make([]string, 0, len(cols))
|
||||||
|
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -202,7 +225,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
if ok == false {
|
if ok == false {
|
||||||
return ErrMissPK
|
return ErrMissPK
|
||||||
}
|
}
|
||||||
whereCols = append(whereCols, pkColumn)
|
whereCols = []string{pkColumn}
|
||||||
args = append(args, pkValue)
|
args = append(args, pkValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,16 +266,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||||
|
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.InsertValue(q, mi, names, values)
|
return d.InsertValue(q, mi, false, names, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
|
// multi-insert sql with given slice struct reflect.Value.
|
||||||
|
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||||
|
var (
|
||||||
|
cnt int64
|
||||||
|
nums int
|
||||||
|
values []interface{}
|
||||||
|
names []string
|
||||||
|
)
|
||||||
|
|
||||||
|
// typ := reflect.Indirect(mi.addrField).Type()
|
||||||
|
|
||||||
|
length := sind.Len()
|
||||||
|
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
|
||||||
|
ind := reflect.Indirect(sind.Index(i - 1))
|
||||||
|
|
||||||
|
// Is this needed ?
|
||||||
|
// if !ind.Type().AssignableTo(typ) {
|
||||||
|
// return cnt, ErrArgs
|
||||||
|
// }
|
||||||
|
|
||||||
|
if i == 1 {
|
||||||
|
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||||
|
if err != nil {
|
||||||
|
return cnt, err
|
||||||
|
}
|
||||||
|
values = make([]interface{}, bulk*len(vus))
|
||||||
|
nums += copy(values, vus)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||||
|
if err != nil {
|
||||||
|
return cnt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vus) != len(names) {
|
||||||
|
return cnt, ErrArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
nums += copy(values[nums:], vus)
|
||||||
|
}
|
||||||
|
|
||||||
|
if i > 1 && i%bulk == 0 || length == i {
|
||||||
|
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
||||||
|
if err != nil {
|
||||||
|
return cnt, err
|
||||||
|
}
|
||||||
|
cnt += num
|
||||||
|
nums = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute insert sql with given struct and given values.
|
||||||
|
// insert the given values, not the field values in struct.
|
||||||
|
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
Q := d.ins.TableQuote()
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
marks := make([]string, len(names))
|
marks := make([]string, len(names))
|
||||||
@ -264,36 +348,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
|
|||||||
qmarks := strings.Join(marks, ", ")
|
qmarks := strings.Join(marks, ", ")
|
||||||
columns := strings.Join(names, sep)
|
columns := strings.Join(names, sep)
|
||||||
|
|
||||||
|
multi := len(values) / len(names)
|
||||||
|
|
||||||
|
if isMulti {
|
||||||
|
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||||
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||||
|
|
||||||
d.ins.ReplaceMarks(&query)
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
if d.ins.HasReturningID(mi, &query) {
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
row := q.QueryRow(query, values...)
|
|
||||||
var id int64
|
|
||||||
err := row.Scan(&id)
|
|
||||||
return id, err
|
|
||||||
} else {
|
|
||||||
if res, err := q.Exec(query, values...); err == nil {
|
if res, err := q.Exec(query, values...); err == nil {
|
||||||
|
if isMulti {
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
return res.LastInsertId()
|
return res.LastInsertId()
|
||||||
} else {
|
} else {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
row := q.QueryRow(query, values...)
|
||||||
|
var id int64
|
||||||
|
err := row.Scan(&id)
|
||||||
|
return id, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if ok == false {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var setNames []string
|
||||||
|
|
||||||
// if specify cols length is zero, then commit all columns.
|
// if specify cols length is zero, then commit all columns.
|
||||||
if len(cols) == 0 {
|
if len(cols) == 0 {
|
||||||
cols = mi.fields.dbcols
|
cols = mi.fields.dbcols
|
||||||
|
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
|
||||||
|
} else {
|
||||||
|
setNames = make([]string, 0, len(cols))
|
||||||
}
|
}
|
||||||
|
|
||||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -317,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute delete sql dbQuerier with given struct reflect.Value.
|
||||||
|
// delete index is pk.
|
||||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if ok == false {
|
||||||
@ -358,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update table-related record by querySet.
|
||||||
|
// need querySet not struct reflect.Value to update related records.
|
||||||
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
||||||
columns := make([]string, 0, len(params))
|
columns := make([]string, 0, len(params))
|
||||||
values := make([]interface{}, 0, len(params))
|
values := make([]interface{}, 0, len(params))
|
||||||
@ -433,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete related records.
|
||||||
|
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
||||||
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
||||||
for _, fi := range mi.fields.fieldsReverse {
|
for _, fi := range mi.fields.fieldsReverse {
|
||||||
fi = fi.reverseFieldInfo
|
fi = fi.reverseFieldInfo
|
||||||
@ -459,8 +564,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete table-related records.
|
||||||
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
|
tables.skipEnd = true
|
||||||
|
|
||||||
if qs != nil {
|
if qs != nil {
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
}
|
}
|
||||||
@ -486,6 +594,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
rs = r
|
rs = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
var ref interface{}
|
var ref interface{}
|
||||||
|
|
||||||
args = make([]interface{}, 0)
|
args = make([]interface{}, 0)
|
||||||
@ -532,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read related records.
|
||||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||||
|
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
@ -640,6 +751,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
slice := ind
|
slice := ind
|
||||||
|
|
||||||
var cnt int64
|
var cnt int64
|
||||||
@ -739,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// excute count sql and return count result int64.
|
||||||
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||||
tables := newDbTables(mi, d.ins)
|
tables := newDbTables(mi, d.ins)
|
||||||
tables.parseRelated(qs.related, qs.relDepth)
|
tables.parseRelated(qs.related, qs.relDepth)
|
||||||
@ -759,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate sql with replacing operator string placeholders and replaced values.
|
||||||
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||||
sql := ""
|
sql := ""
|
||||||
params := getFlatParams(fi, args, tz)
|
params := getFlatParams(fi, args, tz)
|
||||||
@ -812,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
|
|||||||
return sql, params
|
return sql, params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gernerate sql string with inner function, such as UPPER(text).
|
||||||
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
|
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
|
||||||
// default not use
|
// default not use
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set values to struct column.
|
||||||
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
|
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
|
||||||
for i, column := range cols {
|
for i, column := range cols {
|
||||||
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
|
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
|
||||||
@ -837,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert value from database result to value following in field type.
|
||||||
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
|
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
|
||||||
if val == nil {
|
if val == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -989,6 +1107,7 @@ end:
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set one value to struct column field.
|
||||||
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
||||||
|
|
||||||
fieldType := fi.fieldType
|
fieldType := fi.fieldType
|
||||||
@ -1063,6 +1182,7 @@ setValue:
|
|||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query sql, read values , save to *[]ParamList.
|
||||||
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -1150,6 +1270,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
cnt int64
|
cnt int64
|
||||||
columns []string
|
columns []string
|
||||||
@ -1228,6 +1350,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// flag of update joined record.
|
||||||
func (d *dbBase) SupportUpdateJoin() bool {
|
func (d *dbBase) SupportUpdateJoin() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -1236,30 +1363,37 @@ func (d *dbBase) MaxLimit() uint64 {
|
|||||||
return 18446744073709551615
|
return 18446744073709551615
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return quote.
|
||||||
func (d *dbBase) TableQuote() string {
|
func (d *dbBase) TableQuote() string {
|
||||||
return "`"
|
return "`"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replace value placeholer in parametered sql string.
|
||||||
func (d *dbBase) ReplaceMarks(query *string) {
|
func (d *dbBase) ReplaceMarks(query *string) {
|
||||||
// default use `?` as mark, do nothing
|
// default use `?` as mark, do nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// flag of RETURNING sql.
|
||||||
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert time from db.
|
||||||
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
||||||
*t = t.In(tz)
|
*t = t.In(tz)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convert time to db.
|
||||||
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
||||||
*t = t.In(tz)
|
*t = t.In(tz)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get database types.
|
||||||
func (d *dbBase) DbTypes() map[string]string {
|
func (d *dbBase) DbTypes() map[string]string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gt all tables.
|
||||||
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||||
tables := make(map[string]bool)
|
tables := make(map[string]bool)
|
||||||
query := d.ins.ShowTablesQuery()
|
query := d.ins.ShowTablesQuery()
|
||||||
@ -1268,6 +1402,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
|||||||
return tables, err
|
return tables, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var table string
|
var table string
|
||||||
err := rows.Scan(&table)
|
err := rows.Scan(&table)
|
||||||
@ -1282,6 +1418,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
|||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get all cloumns in table.
|
||||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
columns := make(map[string][3]string)
|
columns := make(map[string][3]string)
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
@ -1290,6 +1427,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
|||||||
return columns, err
|
return columns, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var (
|
||||||
name string
|
name string
|
||||||
@ -1306,18 +1445,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
|||||||
return columns, nil
|
return columns, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not implement.
|
||||||
func (d *dbBase) OperatorSql(operator string) string {
|
func (d *dbBase) OperatorSql(operator string) string {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not implement.
|
||||||
func (d *dbBase) ShowTablesQuery() string {
|
func (d *dbBase) ShowTablesQuery() string {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not implement.
|
||||||
func (d *dbBase) ShowColumnsQuery(table string) string {
|
func (d *dbBase) ShowColumnsQuery(table string) string {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not implement.
|
||||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
||||||
panic(ErrNotImplement)
|
panic(ErrNotImplement)
|
||||||
}
|
}
|
||||||
|
154
orm/db_alias.go
154
orm/db_alias.go
@ -3,33 +3,37 @@ package orm
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// database driver constant int.
|
||||||
type DriverType int
|
type DriverType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
_ DriverType = iota
|
_ DriverType = iota // int enum type
|
||||||
DR_MySQL
|
DR_MySQL // mysql
|
||||||
DR_Sqlite
|
DR_Sqlite // sqlite
|
||||||
DR_Oracle
|
DR_Oracle // oracle
|
||||||
DR_Postgres
|
DR_Postgres // pgsql
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// database driver string.
|
||||||
type driver string
|
type driver string
|
||||||
|
|
||||||
|
// get type constant int of current driver..
|
||||||
func (d driver) Type() DriverType {
|
func (d driver) Type() DriverType {
|
||||||
a, _ := dataBaseCache.get(string(d))
|
a, _ := dataBaseCache.get(string(d))
|
||||||
return a.Driver
|
return a.Driver
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get name of current driver
|
||||||
func (d driver) Name() string {
|
func (d driver) Name() string {
|
||||||
return string(d)
|
return string(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check driver iis implemented Driver interface or not.
|
||||||
var _ Driver = new(driver)
|
var _ Driver = new(driver)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -47,11 +51,13 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// database alias cacher.
|
||||||
type _dbCache struct {
|
type _dbCache struct {
|
||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
cache map[string]*alias
|
cache map[string]*alias
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add database alias with original name.
|
||||||
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||||
ac.mux.Lock()
|
ac.mux.Lock()
|
||||||
defer ac.mux.Unlock()
|
defer ac.mux.Unlock()
|
||||||
@ -62,6 +68,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get database alias if cached.
|
||||||
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||||
ac.mux.RLock()
|
ac.mux.RLock()
|
||||||
defer ac.mux.RUnlock()
|
defer ac.mux.RUnlock()
|
||||||
@ -69,6 +76,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get default alias.
|
||||||
func (ac *_dbCache) getDefault() (al *alias) {
|
func (ac *_dbCache) getDefault() (al *alias) {
|
||||||
al, _ = ac.get("default")
|
al, _ = ac.get("default")
|
||||||
return
|
return
|
||||||
@ -87,57 +95,29 @@ type alias struct {
|
|||||||
Engine string
|
Engine string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setting the database connect params. Use the database driver self dataSource args.
|
func detectTZ(al *alias) {
|
||||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
|
||||||
al := new(alias)
|
|
||||||
al.Name = aliasName
|
|
||||||
al.DriverName = driverName
|
|
||||||
al.DataSource = dataSource
|
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
if dr, ok := drivers[driverName]; ok {
|
|
||||||
al.DbBaser = dbBasers[dr]
|
|
||||||
al.Driver = dr
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("driver name `%s` have not registered", driverName)
|
|
||||||
goto end
|
|
||||||
}
|
|
||||||
|
|
||||||
if dataBaseCache.add(aliasName, al) == false {
|
|
||||||
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
|
||||||
goto end
|
|
||||||
}
|
|
||||||
|
|
||||||
al.DB, err = sql.Open(driverName, dataSource)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
|
||||||
goto end
|
|
||||||
}
|
|
||||||
|
|
||||||
// orm timezone system match database
|
// orm timezone system match database
|
||||||
// default use Local
|
// default use Local
|
||||||
al.TZ = time.Local
|
al.TZ = time.Local
|
||||||
|
|
||||||
|
if al.DriverName == "sphinx" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch al.Driver {
|
switch al.Driver {
|
||||||
case DR_MySQL:
|
case DR_MySQL:
|
||||||
row := al.DB.QueryRow("SELECT @@session.time_zone")
|
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
||||||
var tz string
|
var tz string
|
||||||
row.Scan(&tz)
|
row.Scan(&tz)
|
||||||
if tz == "SYSTEM" {
|
if len(tz) >= 8 {
|
||||||
tz = ""
|
if tz[0] != '-' {
|
||||||
row = al.DB.QueryRow("SELECT @@system_time_zone")
|
tz = "+" + tz
|
||||||
row.Scan(&tz)
|
|
||||||
t, err := time.Parse("MST", tz)
|
|
||||||
if err == nil {
|
|
||||||
al.TZ = t.Location()
|
|
||||||
}
|
}
|
||||||
} else {
|
t, err := time.Parse("-07:00:00", tz)
|
||||||
t, err := time.Parse("-07:00", tz)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
al.TZ = t.Location()
|
al.TZ = t.Location()
|
||||||
|
} else {
|
||||||
|
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,8 +143,64 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
|||||||
loc, err := time.LoadLocation(tz)
|
loc, err := time.LoadLocation(tz)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
al.TZ = loc
|
al.TZ = loc
|
||||||
|
} else {
|
||||||
|
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
|
||||||
|
al := new(alias)
|
||||||
|
al.Name = aliasName
|
||||||
|
al.DriverName = driverName
|
||||||
|
al.DB = db
|
||||||
|
|
||||||
|
if dr, ok := drivers[driverName]; ok {
|
||||||
|
al.DbBaser = dbBasers[dr]
|
||||||
|
al.Driver = dr
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.Ping()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if dataBaseCache.add(aliasName, al) == false {
|
||||||
|
return nil, fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return al, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
|
||||||
|
_, err := addAliasWthDB(aliasName, driverName, db)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setting the database connect params. Use the database driver self dataSource args.
|
||||||
|
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
db *sql.DB
|
||||||
|
al *alias
|
||||||
|
)
|
||||||
|
|
||||||
|
db, err = sql.Open(driverName, dataSource)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||||
|
goto end
|
||||||
|
}
|
||||||
|
|
||||||
|
al, err = addAliasWthDB(aliasName, driverName, db)
|
||||||
|
if err != nil {
|
||||||
|
goto end
|
||||||
|
}
|
||||||
|
|
||||||
|
al.DataSource = dataSource
|
||||||
|
|
||||||
|
detectTZ(al)
|
||||||
|
|
||||||
for i, v := range params {
|
for i, v := range params {
|
||||||
switch i {
|
switch i {
|
||||||
@ -175,39 +211,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = al.DB.Ping()
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
|
||||||
goto end
|
|
||||||
}
|
|
||||||
|
|
||||||
end:
|
end:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
if db != nil {
|
||||||
os.Exit(2)
|
db.Close()
|
||||||
}
|
}
|
||||||
|
DebugLog.Println(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register a database driver use specify driver name, this can be definition the driver is which database type.
|
// Register a database driver use specify driver name, this can be definition the driver is which database type.
|
||||||
func RegisterDriver(driverName string, typ DriverType) {
|
func RegisterDriver(driverName string, typ DriverType) error {
|
||||||
if t, ok := drivers[driverName]; ok == false {
|
if t, ok := drivers[driverName]; ok == false {
|
||||||
drivers[driverName] = typ
|
drivers[driverName] = typ
|
||||||
} else {
|
} else {
|
||||||
if t != typ {
|
if t != typ {
|
||||||
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
|
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the database default used timezone
|
// Change the database default used timezone
|
||||||
func SetDataBaseTZ(aliasName string, tz *time.Location) {
|
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
||||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||||
al.TZ = tz
|
al.TZ = tz
|
||||||
} else {
|
} else {
|
||||||
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName)
|
return fmt.Errorf("DataBase name `%s` not registered\n", aliasName)
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change the max idle conns for *sql.DB, use specify database alias name
|
// Change the max idle conns for *sql.DB, use specify database alias name
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mysql operators.
|
||||||
var mysqlOperators = map[string]string{
|
var mysqlOperators = map[string]string{
|
||||||
"exact": "= ?",
|
"exact": "= ?",
|
||||||
"iexact": "LIKE ?",
|
"iexact": "LIKE ?",
|
||||||
@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
|
|||||||
"iendswith": "LIKE ?",
|
"iendswith": "LIKE ?",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mysql column field types.
|
||||||
var mysqlTypes = map[string]string{
|
var mysqlTypes = map[string]string{
|
||||||
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
|
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
|
||||||
"pk": "NOT NULL PRIMARY KEY",
|
"pk": "NOT NULL PRIMARY KEY",
|
||||||
@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
|
|||||||
"float64-decimal": "numeric(%d, %d)",
|
"float64-decimal": "numeric(%d, %d)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mysql dbBaser implementation.
|
||||||
type dbBaseMysql struct {
|
type dbBaseMysql struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ dbBaser = new(dbBaseMysql)
|
var _ dbBaser = new(dbBaseMysql)
|
||||||
|
|
||||||
|
// get mysql operator.
|
||||||
func (d *dbBaseMysql) OperatorSql(operator string) string {
|
func (d *dbBaseMysql) OperatorSql(operator string) string {
|
||||||
return mysqlOperators[operator]
|
return mysqlOperators[operator]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get mysql table field types.
|
||||||
func (d *dbBaseMysql) DbTypes() map[string]string {
|
func (d *dbBaseMysql) DbTypes() map[string]string {
|
||||||
return mysqlTypes
|
return mysqlTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show table sql for mysql.
|
||||||
func (d *dbBaseMysql) ShowTablesQuery() string {
|
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show columns sql of table for mysql.
|
||||||
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute sql to check index exist.
|
||||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||||
@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new mysql dbBaser.
|
||||||
func newdbBaseMysql() dbBaser {
|
func newdbBaseMysql() dbBaser {
|
||||||
b := new(dbBaseMysql)
|
b := new(dbBaseMysql)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package orm
|
package orm
|
||||||
|
|
||||||
|
// oracle dbBaser
|
||||||
type dbBaseOracle struct {
|
type dbBaseOracle struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ dbBaser = new(dbBaseOracle)
|
var _ dbBaser = new(dbBaseOracle)
|
||||||
|
|
||||||
|
// create oracle dbBaser.
|
||||||
func newdbBaseOracle() dbBaser {
|
func newdbBaseOracle() dbBaser {
|
||||||
b := new(dbBaseOracle)
|
b := new(dbBaseOracle)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// postgresql operators.
|
||||||
var postgresOperators = map[string]string{
|
var postgresOperators = map[string]string{
|
||||||
"exact": "= ?",
|
"exact": "= ?",
|
||||||
"iexact": "= UPPER(?)",
|
"iexact": "= UPPER(?)",
|
||||||
@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
|
|||||||
"iendswith": "LIKE UPPER(?)",
|
"iendswith": "LIKE UPPER(?)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postgresql column field types.
|
||||||
var postgresTypes = map[string]string{
|
var postgresTypes = map[string]string{
|
||||||
"auto": "serial NOT NULL PRIMARY KEY",
|
"auto": "serial NOT NULL PRIMARY KEY",
|
||||||
"pk": "NOT NULL PRIMARY KEY",
|
"pk": "NOT NULL PRIMARY KEY",
|
||||||
@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
|
|||||||
"float64-decimal": "numeric(%d, %d)",
|
"float64-decimal": "numeric(%d, %d)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postgresql dbBaser.
|
||||||
type dbBasePostgres struct {
|
type dbBasePostgres struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ dbBaser = new(dbBasePostgres)
|
var _ dbBaser = new(dbBasePostgres)
|
||||||
|
|
||||||
|
// get postgresql operator.
|
||||||
func (d *dbBasePostgres) OperatorSql(operator string) string {
|
func (d *dbBasePostgres) OperatorSql(operator string) string {
|
||||||
return postgresOperators[operator]
|
return postgresOperators[operator]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate functioned sql string, such as contains(text).
|
||||||
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||||
switch operator {
|
switch operator {
|
||||||
case "contains", "startswith", "endswith":
|
case "contains", "startswith", "endswith":
|
||||||
@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postgresql unsupports updating joined record.
|
||||||
func (d *dbBasePostgres) SupportUpdateJoin() bool {
|
func (d *dbBasePostgres) SupportUpdateJoin() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postgresql quote is ".
|
||||||
func (d *dbBasePostgres) TableQuote() string {
|
func (d *dbBasePostgres) TableQuote() string {
|
||||||
return `"`
|
return `"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postgresql value placeholder is $n.
|
||||||
|
// replace default ? to $n.
|
||||||
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||||
q := *query
|
q := *query
|
||||||
num := 0
|
num := 0
|
||||||
@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
|||||||
*query = string(data)
|
*query = string(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make returning sql support for postgresql.
|
||||||
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
|
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
|
||||||
if mi.fields.pk.auto {
|
if mi.fields.pk.auto {
|
||||||
if query != nil {
|
if query != nil {
|
||||||
@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show table sql for postgresql.
|
||||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show table columns sql for postgresql.
|
||||||
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||||
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get column types of postgresql.
|
||||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||||
return postgresTypes
|
return postgresTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check index exist in postgresql.
|
||||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||||
row := db.QueryRow(query)
|
row := db.QueryRow(query)
|
||||||
@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
|
|||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new postgresql dbBaser.
|
||||||
func newdbBasePostgres() dbBaser {
|
func newdbBasePostgres() dbBaser {
|
||||||
b := new(dbBasePostgres)
|
b := new(dbBasePostgres)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// sqlite operators.
|
||||||
var sqliteOperators = map[string]string{
|
var sqliteOperators = map[string]string{
|
||||||
"exact": "= ?",
|
"exact": "= ?",
|
||||||
"iexact": "LIKE ? ESCAPE '\\'",
|
"iexact": "LIKE ? ESCAPE '\\'",
|
||||||
@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
|
|||||||
"iendswith": "LIKE ? ESCAPE '\\'",
|
"iendswith": "LIKE ? ESCAPE '\\'",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sqlite column types.
|
||||||
var sqliteTypes = map[string]string{
|
var sqliteTypes = map[string]string{
|
||||||
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
|
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
|
||||||
"pk": "NOT NULL PRIMARY KEY",
|
"pk": "NOT NULL PRIMARY KEY",
|
||||||
@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
|
|||||||
"float64-decimal": "decimal",
|
"float64-decimal": "decimal",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sqlite dbBaser.
|
||||||
type dbBaseSqlite struct {
|
type dbBaseSqlite struct {
|
||||||
dbBase
|
dbBase
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ dbBaser = new(dbBaseSqlite)
|
var _ dbBaser = new(dbBaseSqlite)
|
||||||
|
|
||||||
|
// get sqlite operator.
|
||||||
func (d *dbBaseSqlite) OperatorSql(operator string) string {
|
func (d *dbBaseSqlite) OperatorSql(operator string) string {
|
||||||
return sqliteOperators[operator]
|
return sqliteOperators[operator]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate functioned sql for sqlite.
|
||||||
|
// only support DATE(text).
|
||||||
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||||
if fi.fieldType == TypeDateField {
|
if fi.fieldType == TypeDateField {
|
||||||
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
|
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unable updating joined record in sqlite.
|
||||||
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// max int in sqlite.
|
||||||
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
||||||
return 9223372036854775807
|
return 9223372036854775807
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get column types in sqlite.
|
||||||
func (d *dbBaseSqlite) DbTypes() map[string]string {
|
func (d *dbBaseSqlite) DbTypes() map[string]string {
|
||||||
return sqliteTypes
|
return sqliteTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get show tables sql in sqlite.
|
||||||
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||||
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get columns in sqlite.
|
||||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||||
query := d.ins.ShowColumnsQuery(table)
|
query := d.ins.ShowColumnsQuery(table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.Query(query)
|
||||||
@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
|
|||||||
return columns, nil
|
return columns, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get show columns sql in sqlite.
|
||||||
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||||
return fmt.Sprintf("pragma table_info('%s')", table)
|
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check index exist in sqlite.
|
||||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||||
rows, err := db.Query(query)
|
rows, err := db.Query(query)
|
||||||
@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new sqlite dbBaser.
|
||||||
func newdbBaseSqlite() dbBaser {
|
func newdbBaseSqlite() dbBaser {
|
||||||
b := new(dbBaseSqlite)
|
b := new(dbBaseSqlite)
|
||||||
b.ins = b
|
b.ins = b
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// table info struct.
|
||||||
type dbTable struct {
|
type dbTable struct {
|
||||||
id int
|
id int
|
||||||
index string
|
index string
|
||||||
@ -18,13 +19,17 @@ type dbTable struct {
|
|||||||
jtl *dbTable
|
jtl *dbTable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tables collection struct, contains some tables.
|
||||||
type dbTables struct {
|
type dbTables struct {
|
||||||
tablesM map[string]*dbTable
|
tablesM map[string]*dbTable
|
||||||
tables []*dbTable
|
tables []*dbTable
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
base dbBaser
|
base dbBaser
|
||||||
|
skipEnd bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set table info to collection.
|
||||||
|
// if not exist, create new.
|
||||||
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
||||||
name := strings.Join(names, ExprSep)
|
name := strings.Join(names, ExprSep)
|
||||||
if j, ok := t.tablesM[name]; ok {
|
if j, ok := t.tablesM[name]; ok {
|
||||||
@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
|||||||
return t.tablesM[name]
|
return t.tablesM[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add table info to collection.
|
||||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||||
name := strings.Join(names, ExprSep)
|
name := strings.Join(names, ExprSep)
|
||||||
if _, ok := t.tablesM[name]; ok == false {
|
if _, ok := t.tablesM[name]; ok == false {
|
||||||
@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
|||||||
return t.tablesM[name], false
|
return t.tablesM[name], false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get table info in collection.
|
||||||
func (t *dbTables) get(name string) (*dbTable, bool) {
|
func (t *dbTables) get(name string) (*dbTable, bool) {
|
||||||
j, ok := t.tablesM[name]
|
j, ok := t.tablesM[name]
|
||||||
return j, ok
|
return j, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get related fields info in recursive depth loop.
|
||||||
|
// loop once, depth decreases one.
|
||||||
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
||||||
if depth < 0 || fi.fieldType == RelManyToMany {
|
if depth < 0 || fi.fieldType == RelManyToMany {
|
||||||
return related
|
return related
|
||||||
@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
|
|||||||
return related
|
return related
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse related fields.
|
||||||
func (t *dbTables) parseRelated(rels []string, depth int) {
|
func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||||
|
|
||||||
relsNum := len(rels)
|
relsNum := len(rels)
|
||||||
@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
|||||||
names = append(names, fi.name)
|
names = append(names, fi.name)
|
||||||
mmi = fi.relModelInfo
|
mmi = fi.relModelInfo
|
||||||
|
|
||||||
if fi.null {
|
if fi.null || t.skipEnd {
|
||||||
inner = false
|
inner = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate join string.
|
||||||
func (t *dbTables) getJoinSql() (join string) {
|
func (t *dbTables) getJoinSql() (join string) {
|
||||||
Q := t.base.TableQuote()
|
Q := t.base.TableQuote()
|
||||||
|
|
||||||
@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse orm model struct field tag expression.
|
||||||
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
||||||
var (
|
var (
|
||||||
jtl *dbTable
|
jtl *dbTable
|
||||||
|
fi *fieldInfo
|
||||||
|
fiN *fieldInfo
|
||||||
mmi = mi
|
mmi = mi
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
|||||||
|
|
||||||
inner := true
|
inner := true
|
||||||
|
|
||||||
|
loopFor:
|
||||||
for i, ex := range exprs {
|
for i, ex := range exprs {
|
||||||
|
|
||||||
fi, ok := mmi.fields.GetByAny(ex)
|
var ok, okN bool
|
||||||
|
|
||||||
|
if fiN != nil {
|
||||||
|
fi = fiN
|
||||||
|
ok = true
|
||||||
|
fiN = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
fi, ok = mmi.fields.GetByAny(ex)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = okN
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
|
|
||||||
@ -216,17 +243,33 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
|||||||
mmi = fi.reverseFieldInfo.mi
|
mmi = fi.reverseFieldInfo.mi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if i < num {
|
||||||
|
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
||||||
|
}
|
||||||
|
|
||||||
if isRel && (fi.mi.isThrough == false || num != i) {
|
if isRel && (fi.mi.isThrough == false || num != i) {
|
||||||
if fi.null {
|
if fi.null || t.skipEnd {
|
||||||
inner = false
|
inner = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.skipEnd && okN || !t.skipEnd {
|
||||||
|
if t.skipEnd && okN && fiN.pk {
|
||||||
|
goto loopEnd
|
||||||
|
}
|
||||||
|
|
||||||
jt, _ := t.add(names, mmi, fi, inner)
|
jt, _ := t.add(names, mmi, fi, inner)
|
||||||
jt.jtl = jtl
|
jt.jtl = jtl
|
||||||
jtl = jt
|
jtl = jt
|
||||||
}
|
}
|
||||||
|
|
||||||
if num == i {
|
}
|
||||||
|
|
||||||
|
if num != i {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
loopEnd:
|
||||||
|
|
||||||
if i == 0 || jtl == nil {
|
if i == 0 || jtl == nil {
|
||||||
index = "T0"
|
index = "T0"
|
||||||
} else {
|
} else {
|
||||||
@ -252,7 +295,8 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
|||||||
name = info.name
|
name = info.name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
break loopFor
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
index = ""
|
index = ""
|
||||||
@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate condition sql.
|
||||||
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
||||||
if cond == nil || cond.IsEmpty() {
|
if cond == nil || cond.IsEmpty() {
|
||||||
return
|
return
|
||||||
@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate order sql.
|
||||||
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||||
if len(orders) == 0 {
|
if len(orders) == 0 {
|
||||||
return
|
return
|
||||||
@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate limit sql.
|
||||||
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
|
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
|
||||||
if limit == 0 {
|
if limit == 0 {
|
||||||
limit = int64(DefaultRowsLimit)
|
limit = int64(DefaultRowsLimit)
|
||||||
@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// crete new tables collection.
|
||||||
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
||||||
tables := &dbTables{}
|
tables := &dbTables{}
|
||||||
tables.tablesM = make(map[string]*dbTable)
|
tables.tablesM = make(map[string]*dbTable)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// get table alias.
|
||||||
func getDbAlias(name string) *alias {
|
func getDbAlias(name string) *alias {
|
||||||
if al, ok := dataBaseCache.get(name); ok {
|
if al, ok := dataBaseCache.get(name); ok {
|
||||||
return al
|
return al
|
||||||
@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get pk column info.
|
||||||
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
||||||
fi := mi.fields.pk
|
fi := mi.fields.pk
|
||||||
|
|
||||||
@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get fields description as flatted string.
|
||||||
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
|
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
|
||||||
|
|
||||||
outFor:
|
outFor:
|
||||||
|
@ -41,6 +41,7 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// model info collection
|
||||||
type _modelCache struct {
|
type _modelCache struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
orders []string
|
orders []string
|
||||||
@ -49,6 +50,7 @@ type _modelCache struct {
|
|||||||
done bool
|
done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get all model info
|
||||||
func (mc *_modelCache) all() map[string]*modelInfo {
|
func (mc *_modelCache) all() map[string]*modelInfo {
|
||||||
m := make(map[string]*modelInfo, len(mc.cache))
|
m := make(map[string]*modelInfo, len(mc.cache))
|
||||||
for k, v := range mc.cache {
|
for k, v := range mc.cache {
|
||||||
@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get orderd model info
|
||||||
func (mc *_modelCache) allOrdered() []*modelInfo {
|
func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||||
m := make([]*modelInfo, 0, len(mc.orders))
|
m := make([]*modelInfo, 0, len(mc.orders))
|
||||||
for _, table := range mc.orders {
|
for _, table := range mc.orders {
|
||||||
@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get model info by table name
|
||||||
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
|
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
|
||||||
mi, ok = mc.cache[table]
|
mi, ok = mc.cache[table]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get model info by field name
|
||||||
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
|
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
|
||||||
mi, ok = mc.cacheByFN[name]
|
mi, ok = mc.cacheByFN[name]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set model info to collection
|
||||||
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||||
mii := mc.cache[table]
|
mii := mc.cache[table]
|
||||||
mc.cache[table] = mi
|
mc.cache[table] = mi
|
||||||
@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
|||||||
return mii
|
return mii
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean all model info.
|
||||||
func (mc *_modelCache) clean() {
|
func (mc *_modelCache) clean() {
|
||||||
mc.orders = make([]string, 0)
|
mc.orders = make([]string, 0)
|
||||||
mc.cache = make(map[string]*modelInfo)
|
mc.cache = make(map[string]*modelInfo)
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// register models.
|
||||||
|
// prefix means table name prefix.
|
||||||
func registerModel(model interface{}, prefix string) {
|
func registerModel(model interface{}, prefix string) {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
|
|||||||
modelCache.set(table, info)
|
modelCache.set(table, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// boostrap models
|
||||||
func bootStrap() {
|
func bootStrap() {
|
||||||
if modelCache.done {
|
if modelCache.done {
|
||||||
return
|
return
|
||||||
@ -281,6 +284,7 @@ end:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// register models
|
||||||
func RegisterModel(models ...interface{}) {
|
func RegisterModel(models ...interface{}) {
|
||||||
if modelCache.done {
|
if modelCache.done {
|
||||||
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
||||||
@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bootrap models.
|
||||||
|
// make all model parsed and can not add more models
|
||||||
func BootStrap() {
|
func BootStrap() {
|
||||||
if modelCache.done {
|
if modelCache.done {
|
||||||
return
|
return
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
var errSkipField = errors.New("skip field")
|
var errSkipField = errors.New("skip field")
|
||||||
|
|
||||||
|
// field info collection
|
||||||
type fields struct {
|
type fields struct {
|
||||||
pk *fieldInfo
|
pk *fieldInfo
|
||||||
columns map[string]*fieldInfo
|
columns map[string]*fieldInfo
|
||||||
@ -23,6 +24,7 @@ type fields struct {
|
|||||||
dbcols []string
|
dbcols []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add field info
|
||||||
func (f *fields) Add(fi *fieldInfo) (added bool) {
|
func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||||
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
|
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
|
||||||
f.columns[fi.column] = fi
|
f.columns[fi.column] = fi
|
||||||
@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get field info by name
|
||||||
func (f *fields) GetByName(name string) *fieldInfo {
|
func (f *fields) GetByName(name string) *fieldInfo {
|
||||||
return f.fields[name]
|
return f.fields[name]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get field info by column name
|
||||||
func (f *fields) GetByColumn(column string) *fieldInfo {
|
func (f *fields) GetByColumn(column string) *fieldInfo {
|
||||||
return f.columns[column]
|
return f.columns[column]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get field info by string, name is prior
|
||||||
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||||
if fi, ok := f.fields[name]; ok {
|
if fi, ok := f.fields[name]; ok {
|
||||||
return fi, ok
|
return fi, ok
|
||||||
@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new field info collection
|
||||||
func newFields() *fields {
|
func newFields() *fields {
|
||||||
f := new(fields)
|
f := new(fields)
|
||||||
f.fields = make(map[string]*fieldInfo)
|
f.fields = make(map[string]*fieldInfo)
|
||||||
@ -79,6 +85,7 @@ func newFields() *fields {
|
|||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// single field info
|
||||||
type fieldInfo struct {
|
type fieldInfo struct {
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
fieldIndex int
|
fieldIndex int
|
||||||
@ -115,6 +122,7 @@ type fieldInfo struct {
|
|||||||
onDelete string
|
onDelete string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// new field info
|
||||||
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
|
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
|
||||||
var (
|
var (
|
||||||
tag string
|
tag string
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// single model info
|
||||||
type modelInfo struct {
|
type modelInfo struct {
|
||||||
pkg string
|
pkg string
|
||||||
name string
|
name string
|
||||||
@ -20,6 +21,7 @@ type modelInfo struct {
|
|||||||
isThrough bool
|
isThrough bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// new model info
|
||||||
func newModelInfo(val reflect.Value) (info *modelInfo) {
|
func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// combine related model info to new model info.
|
||||||
|
// prepare for relation models query.
|
||||||
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
||||||
info = new(modelInfo)
|
info = new(modelInfo)
|
||||||
info.fields = newFields()
|
info.fields = newFields()
|
||||||
|
@ -7,10 +7,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// get reflect.Type name with package path.
|
||||||
func getFullName(typ reflect.Type) string {
|
func getFullName(typ reflect.Type) string {
|
||||||
return typ.PkgPath() + "." + typ.Name()
|
return typ.PkgPath() + "." + typ.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get table name. method, or field name. auto snaked.
|
||||||
func getTableName(val reflect.Value) string {
|
func getTableName(val reflect.Value) string {
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
fun := val.MethodByName("TableName")
|
fun := val.MethodByName("TableName")
|
||||||
@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
|
|||||||
return snakeString(ind.Type().Name())
|
return snakeString(ind.Type().Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get table engine, mysiam or innodb.
|
||||||
func getTableEngine(val reflect.Value) string {
|
func getTableEngine(val reflect.Value) string {
|
||||||
fun := val.MethodByName("TableEngine")
|
fun := val.MethodByName("TableEngine")
|
||||||
if fun.IsValid() {
|
if fun.IsValid() {
|
||||||
@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get table index from method.
|
||||||
func getTableIndex(val reflect.Value) [][]string {
|
func getTableIndex(val reflect.Value) [][]string {
|
||||||
fun := val.MethodByName("TableIndex")
|
fun := val.MethodByName("TableIndex")
|
||||||
if fun.IsValid() {
|
if fun.IsValid() {
|
||||||
@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get table unique from method
|
||||||
func getTableUnique(val reflect.Value) [][]string {
|
func getTableUnique(val reflect.Value) [][]string {
|
||||||
fun := val.MethodByName("TableUnique")
|
fun := val.MethodByName("TableUnique")
|
||||||
if fun.IsValid() {
|
if fun.IsValid() {
|
||||||
@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get snaked column name
|
||||||
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
||||||
col = strings.ToLower(col)
|
col = strings.ToLower(col)
|
||||||
column := col
|
column := col
|
||||||
@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
|
|||||||
return column
|
return column
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return field type as type constant from reflect.Value
|
||||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||||
elm := reflect.Indirect(val)
|
elm := reflect.Indirect(val)
|
||||||
switch elm.Kind() {
|
switch elm.Kind() {
|
||||||
@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse struct tag string
|
||||||
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
|
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
|
||||||
attr := make(map[string]bool)
|
attr := make(map[string]bool)
|
||||||
tag := make(map[string]string)
|
tag := make(map[string]string)
|
||||||
|
151
orm/orm.go
151
orm/orm.go
@ -25,6 +25,7 @@ var (
|
|||||||
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
|
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
|
||||||
ErrNoRows = errors.New("<QuerySeter> no row found")
|
ErrNoRows = errors.New("<QuerySeter> no row found")
|
||||||
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
|
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
|
||||||
|
ErrArgs = errors.New("<Ormer> args error may be empty")
|
||||||
ErrNotImplement = errors.New("have not implement")
|
ErrNotImplement = errors.New("have not implement")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,11 +40,12 @@ type orm struct {
|
|||||||
|
|
||||||
var _ Ormer = new(orm)
|
var _ Ormer = new(orm)
|
||||||
|
|
||||||
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
// get model info and model reflect value
|
||||||
|
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
|
||||||
val := reflect.ValueOf(md)
|
val := reflect.ValueOf(md)
|
||||||
ind = reflect.Indirect(val)
|
ind = reflect.Indirect(val)
|
||||||
typ := ind.Type()
|
typ := ind.Type()
|
||||||
if val.Kind() != reflect.Ptr {
|
if needPtr && val.Kind() != reflect.Ptr {
|
||||||
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
||||||
}
|
}
|
||||||
name := getFullName(typ)
|
name := getFullName(typ)
|
||||||
@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
|||||||
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get field info from model info by given field name
|
||||||
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||||
fi, ok := mi.fields.GetByAny(name)
|
fi, ok := mi.fields.GetByAny(name)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
|||||||
return fi
|
return fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read data to model
|
||||||
func (o *orm) Read(md interface{}, cols ...string) error {
|
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -70,13 +74,35 @@ func (o *orm) Read(md interface{}, cols ...string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to read a row from the database, or insert one if it doesn't exist
|
||||||
|
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||||
|
cols = append([]string{col1}, cols...)
|
||||||
|
mi, ind := o.getMiInd(md, true)
|
||||||
|
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
|
if err == ErrNoRows {
|
||||||
|
// Create
|
||||||
|
id, err := o.Insert(md)
|
||||||
|
return (err == nil), id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ind.Field(mi.fields.pk.fieldIndex).Int(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert model data to database
|
||||||
func (o *orm) Insert(md interface{}) (int64, error) {
|
func (o *orm) Insert(md interface{}) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
if id > 0 {
|
|
||||||
|
o.setPk(mi, ind, id)
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// set auto pk field
|
||||||
|
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
|
||||||
if mi.fields.pk.auto {
|
if mi.fields.pk.auto {
|
||||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||||
@ -85,11 +111,46 @@ func (o *orm) Insert(md interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return id, nil
|
|
||||||
|
// insert some models to database
|
||||||
|
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||||
|
var cnt int64
|
||||||
|
|
||||||
|
sind := reflect.Indirect(reflect.ValueOf(mds))
|
||||||
|
|
||||||
|
switch sind.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
if sind.Len() == 0 {
|
||||||
|
return cnt, ErrArgs
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return cnt, ErrArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bulk <= 1 {
|
||||||
|
for i := 0; i < sind.Len(); i++ {
|
||||||
|
ind := sind.Index(i)
|
||||||
|
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||||
|
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||||
|
if err != nil {
|
||||||
|
return cnt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
o.setPk(mi, ind, id)
|
||||||
|
|
||||||
|
cnt += 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||||
|
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
||||||
|
}
|
||||||
|
return cnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update model to database.
|
||||||
|
// cols set the columns those want to update.
|
||||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
@ -97,26 +158,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
|||||||
return num, nil
|
return num, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete model in database
|
||||||
func (o *orm) Delete(md interface{}) (int64, error) {
|
func (o *orm) Delete(md interface{}) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
|
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return num, err
|
return num, err
|
||||||
}
|
}
|
||||||
if num > 0 {
|
if num > 0 {
|
||||||
if mi.fields.pk.auto {
|
o.setPk(mi, ind, 0)
|
||||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
|
||||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
|
|
||||||
} else {
|
|
||||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return num, nil
|
return num, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a models to models queryer
|
||||||
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
fi := o.getFieldInfo(mi, name)
|
fi := o.getFieldInfo(mi, name)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@ -129,6 +186,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
|||||||
return newQueryM2M(md, o, mi, fi, ind)
|
return newQueryM2M(md, o, mi, fi, ind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load related models to md model.
|
||||||
|
// args are limit, offset int and order string.
|
||||||
|
//
|
||||||
|
// example:
|
||||||
|
// orm.LoadRelated(post,"Tags")
|
||||||
|
// for _,tag := range post.Tags{...}
|
||||||
|
//
|
||||||
|
// make sure the relation is defined in model struct tags.
|
||||||
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
|
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
|
||||||
_, fi, ind, qseter := o.queryRelated(md, name)
|
_, fi, ind, qseter := o.queryRelated(md, name)
|
||||||
|
|
||||||
@ -190,14 +255,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
|
|||||||
return nums, err
|
return nums, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return a QuerySeter for related models to md model.
|
||||||
|
// it can do all, update, delete in QuerySeter.
|
||||||
|
// example:
|
||||||
|
// qs := orm.QueryRelated(post,"Tag")
|
||||||
|
// qs.All(&[]*Tag{})
|
||||||
|
//
|
||||||
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
|
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
|
||||||
// is this api needed ?
|
// is this api needed ?
|
||||||
_, _, _, qs := o.queryRelated(md, name)
|
_, _, _, qs := o.queryRelated(md, name)
|
||||||
return qs
|
return qs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get QuerySeter for related models to md model
|
||||||
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
|
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
|
||||||
mi, ind := o.getMiInd(md)
|
mi, ind := o.getMiInd(md, true)
|
||||||
fi := o.getFieldInfo(mi, name)
|
fi := o.getFieldInfo(mi, name)
|
||||||
|
|
||||||
_, _, exist := getExistPk(mi, ind)
|
_, _, exist := getExistPk(mi, ind)
|
||||||
@ -227,6 +299,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
|
|||||||
return mi, fi, ind, qs
|
return mi, fi, ind, qs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get reverse relation QuerySeter
|
||||||
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||||
switch fi.fieldType {
|
switch fi.fieldType {
|
||||||
case RelReverseOne, RelReverseMany:
|
case RelReverseOne, RelReverseMany:
|
||||||
@ -247,6 +320,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
|
|||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get relation QuerySeter
|
||||||
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||||
switch fi.fieldType {
|
switch fi.fieldType {
|
||||||
case RelOneToOne, RelForeignKey, RelManyToMany:
|
case RelOneToOne, RelForeignKey, RelManyToMany:
|
||||||
@ -266,6 +340,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
|||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return a QuerySeter for table operations.
|
||||||
|
// table name can be string or struct.
|
||||||
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
name := ""
|
name := ""
|
||||||
if table, ok := ptrStructOrTableName.(string); ok {
|
if table, ok := ptrStructOrTableName.(string); ok {
|
||||||
@ -285,6 +362,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// switch to another registered database driver by given name.
|
||||||
func (o *orm) Using(name string) error {
|
func (o *orm) Using(name string) error {
|
||||||
if o.isTx {
|
if o.isTx {
|
||||||
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
|
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
|
||||||
@ -302,6 +380,7 @@ func (o *orm) Using(name string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// begin transaction
|
||||||
func (o *orm) Begin() error {
|
func (o *orm) Begin() error {
|
||||||
if o.isTx {
|
if o.isTx {
|
||||||
return ErrTxHasBegan
|
return ErrTxHasBegan
|
||||||
@ -320,6 +399,7 @@ func (o *orm) Begin() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// commit transaction
|
||||||
func (o *orm) Commit() error {
|
func (o *orm) Commit() error {
|
||||||
if o.isTx == false {
|
if o.isTx == false {
|
||||||
return ErrTxDone
|
return ErrTxDone
|
||||||
@ -334,6 +414,7 @@ func (o *orm) Commit() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollback transaction
|
||||||
func (o *orm) Rollback() error {
|
func (o *orm) Rollback() error {
|
||||||
if o.isTx == false {
|
if o.isTx == false {
|
||||||
return ErrTxDone
|
return ErrTxDone
|
||||||
@ -348,14 +429,23 @@ func (o *orm) Rollback() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return a raw query seter for raw sql string.
|
||||||
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
|
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
|
||||||
return newRawSet(o, query, args)
|
return newRawSet(o, query, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return current using database Driver
|
||||||
func (o *orm) Driver() Driver {
|
func (o *orm) Driver() Driver {
|
||||||
return driver(o.alias.Name)
|
return driver(o.alias.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *orm) GetDB() dbQuerier {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
// not enough
|
||||||
|
return o.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new orm
|
||||||
func NewOrm() Ormer {
|
func NewOrm() Ormer {
|
||||||
BootStrap() // execute only once
|
BootStrap() // execute only once
|
||||||
|
|
||||||
@ -366,3 +456,30 @@ func NewOrm() Ormer {
|
|||||||
}
|
}
|
||||||
return o
|
return o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a new ormer object with specify *sql.DB for query
|
||||||
|
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
|
||||||
|
var al *alias
|
||||||
|
|
||||||
|
if dr, ok := drivers[driverName]; ok {
|
||||||
|
al = new(alias)
|
||||||
|
al.DbBaser = dbBasers[dr]
|
||||||
|
al.Driver = dr
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
al.Name = aliasName
|
||||||
|
al.DriverName = driverName
|
||||||
|
|
||||||
|
o := new(orm)
|
||||||
|
o.alias = al
|
||||||
|
|
||||||
|
if Debug {
|
||||||
|
o.db = newDbQueryLog(o.alias, db)
|
||||||
|
} else {
|
||||||
|
o.db = db
|
||||||
|
}
|
||||||
|
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
@ -18,15 +18,19 @@ type condValue struct {
|
|||||||
isCond bool
|
isCond bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// condition struct.
|
||||||
|
// work for WHERE conditions.
|
||||||
type Condition struct {
|
type Condition struct {
|
||||||
params []condValue
|
params []condValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return new condition struct
|
||||||
func NewCondition() *Condition {
|
func NewCondition() *Condition {
|
||||||
c := &Condition{}
|
c := &Condition{}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add expression to condition
|
||||||
func (c Condition) And(expr string, args ...interface{}) *Condition {
|
func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||||
if expr == "" || len(args) == 0 {
|
if expr == "" || len(args) == 0 {
|
||||||
panic(fmt.Errorf("<Condition.And> args cannot empty"))
|
panic(fmt.Errorf("<Condition.And> args cannot empty"))
|
||||||
@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add NOT expression to condition
|
||||||
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||||
if expr == "" || len(args) == 0 {
|
if expr == "" || len(args) == 0 {
|
||||||
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
|
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
|
||||||
@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// combine a condition to current condition
|
||||||
func (c *Condition) AndCond(cond *Condition) *Condition {
|
func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||||
c = c.clone()
|
c = c.clone()
|
||||||
if c == cond {
|
if c == cond {
|
||||||
@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add OR expression to condition
|
||||||
func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||||
if expr == "" || len(args) == 0 {
|
if expr == "" || len(args) == 0 {
|
||||||
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
|
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
|
||||||
@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add OR NOT expression to condition
|
||||||
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||||
if expr == "" || len(args) == 0 {
|
if expr == "" || len(args) == 0 {
|
||||||
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
|
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
|
||||||
@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// combine a OR condition to current condition
|
||||||
func (c *Condition) OrCond(cond *Condition) *Condition {
|
func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||||
c = c.clone()
|
c = c.clone()
|
||||||
if c == cond {
|
if c == cond {
|
||||||
@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check the condition arguments are empty or not.
|
||||||
func (c *Condition) IsEmpty() bool {
|
func (c *Condition) IsEmpty() bool {
|
||||||
return len(c.params) == 0
|
return len(c.params) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clone a condition
|
||||||
func (c Condition) clone() *Condition {
|
func (c Condition) clone() *Condition {
|
||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ type Log struct {
|
|||||||
*log.Logger
|
*log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set io.Writer to create a Logger.
|
||||||
func NewLog(out io.Writer) *Log {
|
func NewLog(out io.Writer) *Log {
|
||||||
d := new(Log)
|
d := new(Log)
|
||||||
d.Logger = log.New(out, "[ORM]", 1e9)
|
d.Logger = log.New(out, "[ORM]", 1e9)
|
||||||
@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
|
|||||||
DebugLog.Println(con)
|
DebugLog.Println(con)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// statement query logger struct.
|
||||||
|
// if dev mode, use stmtQueryLog, or use stmtQuerier.
|
||||||
type stmtQueryLog struct {
|
type stmtQueryLog struct {
|
||||||
alias *alias
|
alias *alias
|
||||||
query string
|
query string
|
||||||
@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// database query logger struct.
|
||||||
|
// if dev mode, use dbQueryLog, or use dbQuerier.
|
||||||
type dbQueryLog struct {
|
type dbQueryLog struct {
|
||||||
alias *alias
|
alias *alias
|
||||||
db dbQuerier
|
db dbQuerier
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// an insert queryer struct
|
||||||
type insertSet struct {
|
type insertSet struct {
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
orm *orm
|
orm *orm
|
||||||
@ -14,6 +15,7 @@ type insertSet struct {
|
|||||||
|
|
||||||
var _ Inserter = new(insertSet)
|
var _ Inserter = new(insertSet)
|
||||||
|
|
||||||
|
// insert model ignore it's registered or not.
|
||||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||||
if o.closed {
|
if o.closed {
|
||||||
return 0, ErrStmtClosed
|
return 0, ErrStmtClosed
|
||||||
@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// close insert queryer statement
|
||||||
func (o *insertSet) Close() error {
|
func (o *insertSet) Close() error {
|
||||||
if o.closed {
|
if o.closed {
|
||||||
return ErrStmtClosed
|
return ErrStmtClosed
|
||||||
@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
|
|||||||
return o.stmt.Close()
|
return o.stmt.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new insert queryer.
|
||||||
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
|
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
|
||||||
bi := new(insertSet)
|
bi := new(insertSet)
|
||||||
bi.orm = orm
|
bi.orm = orm
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// model to model struct
|
||||||
type queryM2M struct {
|
type queryM2M struct {
|
||||||
md interface{}
|
md interface{}
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
@ -12,6 +13,13 @@ type queryM2M struct {
|
|||||||
ind reflect.Value
|
ind reflect.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add models to origin models when creating queryM2M.
|
||||||
|
// example:
|
||||||
|
// m2m := orm.QueryM2M(post,"Tag")
|
||||||
|
// m2m.Add(&Tag1{},&Tag2{})
|
||||||
|
// for _,tag := range post.Tags{}
|
||||||
|
//
|
||||||
|
// make sure the relation is defined in post model struct tag.
|
||||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
mi := fi.relThroughModelInfo
|
mi := fi.relThroughModelInfo
|
||||||
@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
|
|
||||||
names := []string{mfi.column, rfi.column}
|
names := []string{mfi.column, rfi.column}
|
||||||
|
|
||||||
var nums int64
|
values := make([]interface{}, 0, len(models)*2)
|
||||||
|
|
||||||
for _, md := range models {
|
for _, md := range models {
|
||||||
|
|
||||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||||
@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []interface{}{v1, v2}
|
values = append(values, v1, v2)
|
||||||
_, err := dbase.InsertValue(orm.db, mi, names, values)
|
|
||||||
if err != nil {
|
|
||||||
return nums, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nums += 1
|
return dbase.InsertValue(orm.db, mi, true, names, values)
|
||||||
}
|
|
||||||
|
|
||||||
return nums, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove models following the origin model relationship
|
||||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||||
@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
|||||||
return nums, nil
|
return nums, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check model is existed in relationship of origin model
|
||||||
func (o *queryM2M) Exist(md interface{}) bool {
|
func (o *queryM2M) Exist(md interface{}) bool {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean all models in related of origin model
|
||||||
func (o *queryM2M) Clear() (int64, error) {
|
func (o *queryM2M) Clear() (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// count all related models of origin model
|
||||||
func (o *queryM2M) Count() (int64, error) {
|
func (o *queryM2M) Count() (int64, error) {
|
||||||
fi := o.fi
|
fi := o.fi
|
||||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
||||||
@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
|
|||||||
|
|
||||||
var _ QueryM2Mer = new(queryM2M)
|
var _ QueryM2Mer = new(queryM2M)
|
||||||
|
|
||||||
|
// create new M2M queryer.
|
||||||
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
|
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
|
||||||
qm2m := new(queryM2M)
|
qm2m := new(queryM2M)
|
||||||
qm2m.md = md
|
qm2m.md = md
|
||||||
|
@ -18,6 +18,10 @@ const (
|
|||||||
Col_Except
|
Col_Except
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
|
||||||
|
// Params{
|
||||||
|
// "Nums": ColValue(Col_Add, 10),
|
||||||
|
// }
|
||||||
func ColValue(opt operator, value interface{}) interface{} {
|
func ColValue(opt operator, value interface{}) interface{} {
|
||||||
switch opt {
|
switch opt {
|
||||||
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
|
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
|
||||||
@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
|
|||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// real query struct
|
||||||
type querySet struct {
|
type querySet struct {
|
||||||
mi *modelInfo
|
mi *modelInfo
|
||||||
cond *Condition
|
cond *Condition
|
||||||
@ -47,6 +52,7 @@ type querySet struct {
|
|||||||
|
|
||||||
var _ QuerySeter = new(querySet)
|
var _ QuerySeter = new(querySet)
|
||||||
|
|
||||||
|
// add condition expression to QuerySeter.
|
||||||
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||||
if o.cond == nil {
|
if o.cond == nil {
|
||||||
o.cond = NewCondition()
|
o.cond = NewCondition()
|
||||||
@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add NOT condition to querySeter.
|
||||||
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||||
if o.cond == nil {
|
if o.cond == nil {
|
||||||
o.cond = NewCondition()
|
o.cond = NewCondition()
|
||||||
@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set offset number
|
||||||
func (o *querySet) setOffset(num interface{}) {
|
func (o *querySet) setOffset(num interface{}) {
|
||||||
o.offset = ToInt64(num)
|
o.offset = ToInt64(num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add LIMIT value.
|
||||||
|
// args[0] means offset, e.g. LIMIT num,offset.
|
||||||
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||||
o.limit = ToInt64(limit)
|
o.limit = ToInt64(limit)
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add OFFSET value
|
||||||
func (o querySet) Offset(offset interface{}) QuerySeter {
|
func (o querySet) Offset(offset interface{}) QuerySeter {
|
||||||
o.setOffset(offset)
|
o.setOffset(offset)
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add ORDER expression.
|
||||||
|
// "column" means ASC, "-column" means DESC.
|
||||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
||||||
o.orders = exprs
|
o.orders = exprs
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set relation model to query together.
|
||||||
|
// it will query relation models and assign to parent model.
|
||||||
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||||
var related []string
|
var related []string
|
||||||
if len(params) == 0 {
|
if len(params) == 0 {
|
||||||
@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set condition to QuerySeter.
|
||||||
func (o querySet) SetCond(cond *Condition) QuerySeter {
|
func (o querySet) SetCond(cond *Condition) QuerySeter {
|
||||||
o.cond = cond
|
o.cond = cond
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return QuerySeter execution result number
|
||||||
func (o *querySet) Count() (int64, error) {
|
func (o *querySet) Count() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check result empty or not after QuerySeter executed
|
||||||
func (o *querySet) Exist() bool {
|
func (o *querySet) Exist() bool {
|
||||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute update with parameters
|
||||||
func (o *querySet) Update(values Params) (int64, error) {
|
func (o *querySet) Update(values Params) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute delete
|
||||||
func (o *querySet) Delete() (int64, error) {
|
func (o *querySet) Delete() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// return a insert queryer.
|
||||||
|
// it can be used in times.
|
||||||
|
// example:
|
||||||
|
// i,err := sq.PrepareInsert()
|
||||||
|
// i.Add(&user1{},&user2{})
|
||||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||||
return newInsertSet(o.orm, o.mi)
|
return newInsertSet(o.orm, o.mi)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query all data and map to containers.
|
||||||
|
// cols means the columns when querying.
|
||||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query one row data and map to containers.
|
||||||
|
// cols means the columns when querying.
|
||||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -149,18 +178,56 @@ func (o *querySet) One(container interface{}, cols ...string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query all data and map to []map[string]interface.
|
||||||
|
// expres means condition expression.
|
||||||
|
// it converts data to []map[column]value.
|
||||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query all data and map to [][]interface
|
||||||
|
// it converts data to [][column_index]value
|
||||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query all data and map to []interface.
|
||||||
|
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
|
// keyCol = "name", valueCol = "value"
|
||||||
|
// table data
|
||||||
|
// name | value
|
||||||
|
// total | 100
|
||||||
|
// found | 200
|
||||||
|
// to map[string]interface{}{
|
||||||
|
// "total": 100,
|
||||||
|
// "found": 200,
|
||||||
|
// }
|
||||||
|
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, result, keyCol, valueCol, o.orm.alias.TZ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query all rows into struct with specify key and value column name.
|
||||||
|
// keyCol = "name", valueCol = "value"
|
||||||
|
// table data
|
||||||
|
// name | value
|
||||||
|
// total | 100
|
||||||
|
// found | 200
|
||||||
|
// to struct {
|
||||||
|
// Total int
|
||||||
|
// Found int
|
||||||
|
// }
|
||||||
|
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||||
|
panic(ErrNotImplement)
|
||||||
|
return o.orm.alias.DbBaser.RowsTo(o.orm.db, o, o.mi, o.cond, ptrStruct, keyCol, valueCol, o.orm.alias.TZ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new QuerySeter.
|
||||||
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
|
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
|
||||||
o := new(querySet)
|
o := new(querySet)
|
||||||
o.mi = mi
|
o.mi = mi
|
||||||
|
487
orm/orm_raw.go
487
orm/orm_raw.go
@ -4,10 +4,10 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// raw sql string prepared statement
|
||||||
type rawPrepare struct {
|
type rawPrepare struct {
|
||||||
rs *rawSet
|
rs *rawSet
|
||||||
stmt stmtQuerier
|
stmt stmtQuerier
|
||||||
@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
|
|||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// raw query seter
|
||||||
type rawSet struct {
|
type rawSet struct {
|
||||||
query string
|
query string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
@ -53,11 +54,13 @@ type rawSet struct {
|
|||||||
|
|
||||||
var _ RawSeter = new(rawSet)
|
var _ RawSeter = new(rawSet)
|
||||||
|
|
||||||
|
// set args for every query
|
||||||
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
||||||
o.args = args
|
o.args = args
|
||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute raw sql and return sql.Result
|
||||||
func (o *rawSet) Exec() (sql.Result, error) {
|
func (o *rawSet) Exec() (sql.Result, error) {
|
||||||
query := o.query
|
query := o.query
|
||||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||||
@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
|
|||||||
return o.orm.db.Exec(query, args...)
|
return o.orm.db.Exec(query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set field value to row container
|
||||||
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||||
switch ind.Kind() {
|
switch ind.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
|
// set field value in loop for slice container
|
||||||
sIdxes := *sIdxesPtr
|
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||||
refs := *refsPtr
|
|
||||||
|
|
||||||
if typ.Kind() == reflect.Struct {
|
|
||||||
if typ.String() == "time.Time" {
|
|
||||||
var ref interface{}
|
|
||||||
refs = append(refs, &ref)
|
|
||||||
sIdxes = append(sIdxes, []int{0})
|
|
||||||
} else {
|
|
||||||
idxs := []int{}
|
|
||||||
outFor:
|
|
||||||
for idx := 0; idx < typ.NumField(); idx++ {
|
|
||||||
ctyp := typ.Field(idx)
|
|
||||||
|
|
||||||
tag := ctyp.Tag.Get(defaultStructTagName)
|
|
||||||
for _, v := range strings.Split(tag, defaultStructTagDelim) {
|
|
||||||
if v == "-" {
|
|
||||||
continue outFor
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tp := ctyp.Type
|
|
||||||
if tp.Kind() == reflect.Ptr {
|
|
||||||
tp = tp.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if tp.String() == "time.Time" {
|
|
||||||
var ref interface{}
|
|
||||||
refs = append(refs, &ref)
|
|
||||||
|
|
||||||
} else if tp.Kind() != reflect.Struct {
|
|
||||||
var ref interface{}
|
|
||||||
refs = append(refs, &ref)
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// skip other type
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
idxs = append(idxs, idx)
|
|
||||||
}
|
|
||||||
sIdxes = append(sIdxes, idxs)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var ref interface{}
|
|
||||||
refs = append(refs, &ref)
|
|
||||||
sIdxes = append(sIdxes, []int{0})
|
|
||||||
}
|
|
||||||
|
|
||||||
*sIdxesPtr = sIdxes
|
|
||||||
*refsPtr = refs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
|
||||||
nInds := *nIndsPtr
|
nInds := *nIndsPtr
|
||||||
|
|
||||||
cur := 0
|
cur := 0
|
||||||
for i, idxs := range sIdxes {
|
for i := 0; i < len(sInds); i++ {
|
||||||
sInd := sInds[i]
|
sInd := sInds[i]
|
||||||
eTyp := eTyps[i]
|
eTyp := eTyps[i]
|
||||||
|
|
||||||
@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
|||||||
o.setFieldValue(ind, value)
|
o.setFieldValue(ind, value)
|
||||||
}
|
}
|
||||||
cur++
|
cur++
|
||||||
} else {
|
|
||||||
hasValue := false
|
|
||||||
for _, idx := range idxs {
|
|
||||||
tind := ind.Field(idx)
|
|
||||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
|
||||||
if value != nil {
|
|
||||||
hasValue = true
|
|
||||||
}
|
|
||||||
if tind.Kind() == reflect.Ptr {
|
|
||||||
if value == nil {
|
|
||||||
tindV := reflect.New(tind.Type()).Elem()
|
|
||||||
tind.Set(tindV)
|
|
||||||
} else {
|
|
||||||
tindV := reflect.New(tind.Type().Elem())
|
|
||||||
o.setFieldValue(tindV.Elem(), value)
|
|
||||||
tind.Set(tindV)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
o.setFieldValue(tind, value)
|
|
||||||
}
|
|
||||||
cur++
|
|
||||||
}
|
|
||||||
if hasValue == false && isPtr {
|
|
||||||
val = reflect.New(val.Type()).Elem()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||||
if isPtr && value == nil {
|
if isPtr && value == nil {
|
||||||
@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query data and map to container
|
||||||
func (o *rawSet) QueryRow(containers ...interface{}) error {
|
func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||||
if len(containers) == 0 {
|
|
||||||
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
|
|
||||||
}
|
|
||||||
|
|
||||||
refs := make([]interface{}, 0, len(containers))
|
refs := make([]interface{}, 0, len(containers))
|
||||||
sIdxes := make([][]int, 0)
|
|
||||||
sInds := make([]reflect.Value, 0)
|
sInds := make([]reflect.Value, 0)
|
||||||
eTyps := make([]reflect.Type, 0)
|
eTyps := make([]reflect.Type, 0)
|
||||||
|
|
||||||
|
structMode := false
|
||||||
|
var sMi *modelInfo
|
||||||
for _, container := range containers {
|
for _, container := range containers {
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
ind := reflect.Indirect(val)
|
ind := reflect.Indirect(val)
|
||||||
@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
|
|||||||
if typ.Kind() == reflect.Ptr {
|
if typ.Kind() == reflect.Ptr {
|
||||||
typ = typ.Elem()
|
typ = typ.Elem()
|
||||||
}
|
}
|
||||||
if typ.Kind() == reflect.Ptr {
|
|
||||||
typ = typ.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
sInds = append(sInds, ind)
|
sInds = append(sInds, ind)
|
||||||
eTyps = append(eTyps, etyp)
|
eTyps = append(eTyps, etyp)
|
||||||
|
|
||||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||||
|
if len(containers) > 1 {
|
||||||
|
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||||
|
}
|
||||||
|
|
||||||
|
structMode = true
|
||||||
|
fn := getFullName(typ)
|
||||||
|
if mi, ok := modelCache.getByFN(fn); ok {
|
||||||
|
sMi = mi
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var ref interface{}
|
||||||
|
refs = append(refs, &ref)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := o.query
|
query := o.query
|
||||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||||
|
|
||||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||||
row := o.orm.db.QueryRow(query, args...)
|
rows, err := o.orm.db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
if err := row.Scan(refs...); err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return ErrNoRows
|
return ErrNoRows
|
||||||
} else if err != nil {
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if rows.Next() {
|
||||||
|
if structMode {
|
||||||
|
columns, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
columnsMp := make(map[string]interface{}, len(columns))
|
||||||
|
|
||||||
|
refs = make([]interface{}, 0, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
var ref interface{}
|
||||||
|
columnsMp[col] = &ref
|
||||||
|
refs = append(refs, &ref)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Scan(refs...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ind := sInds[0]
|
||||||
|
|
||||||
|
if ind.Kind() == reflect.Ptr {
|
||||||
|
if ind.IsNil() || !ind.IsValid() {
|
||||||
|
ind.Set(reflect.New(eTyps[0].Elem()))
|
||||||
|
}
|
||||||
|
ind = ind.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if sMi != nil {
|
||||||
|
for _, col := range columns {
|
||||||
|
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||||
|
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||||
|
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < ind.NumField(); i++ {
|
||||||
|
f := ind.Field(i)
|
||||||
|
fe := ind.Type().Field(i)
|
||||||
|
|
||||||
|
var attrs map[string]bool
|
||||||
|
var tags map[string]string
|
||||||
|
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||||
|
var col string
|
||||||
|
if col = tags["column"]; len(col) == 0 {
|
||||||
|
col = snakeString(fe.Name)
|
||||||
|
}
|
||||||
|
if v, ok := columnsMp[col]; ok {
|
||||||
|
value := reflect.ValueOf(v).Elem().Interface()
|
||||||
|
o.setFieldValue(f, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if err := rows.Scan(refs...); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nInds := make([]reflect.Value, len(sInds))
|
nInds := make([]reflect.Value, len(sInds))
|
||||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
|
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
|
||||||
for i, sInd := range sInds {
|
for i, sInd := range sInds {
|
||||||
nInd := nInds[i]
|
nInd := nInds[i]
|
||||||
sInd.Set(nInd)
|
sInd.Set(nInd)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
return ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query data rows and map to container
|
||||||
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||||
refs := make([]interface{}, 0)
|
refs := make([]interface{}, 0, len(containers))
|
||||||
sIdxes := make([][]int, 0)
|
|
||||||
sInds := make([]reflect.Value, 0)
|
sInds := make([]reflect.Value, 0)
|
||||||
eTyps := make([]reflect.Type, 0)
|
eTyps := make([]reflect.Type, 0)
|
||||||
|
|
||||||
|
structMode := false
|
||||||
|
var sMi *modelInfo
|
||||||
for _, container := range containers {
|
for _, container := range containers {
|
||||||
val := reflect.ValueOf(container)
|
val := reflect.ValueOf(container)
|
||||||
sInd := reflect.Indirect(val)
|
sInd := reflect.Indirect(val)
|
||||||
@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
|||||||
sInds = append(sInds, sInd)
|
sInds = append(sInds, sInd)
|
||||||
eTyps = append(eTyps, etyp)
|
eTyps = append(eTyps, etyp)
|
||||||
|
|
||||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||||
|
if len(containers) > 1 {
|
||||||
|
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||||
|
}
|
||||||
|
|
||||||
|
structMode = true
|
||||||
|
fn := getFullName(typ)
|
||||||
|
if mi, ok := modelCache.getByFN(fn); ok {
|
||||||
|
sMi = mi
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var ref interface{}
|
||||||
|
refs = append(refs, &ref)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := o.query
|
query := o.query
|
||||||
@ -401,30 +418,107 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nInds := make([]reflect.Value, len(sInds))
|
defer rows.Close()
|
||||||
|
|
||||||
var cnt int64
|
var cnt int64
|
||||||
|
nInds := make([]reflect.Value, len(sInds))
|
||||||
|
sInd := sInds[0]
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
||||||
|
if structMode {
|
||||||
|
columns, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
columnsMp := make(map[string]interface{}, len(columns))
|
||||||
|
|
||||||
|
refs = make([]interface{}, 0, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
var ref interface{}
|
||||||
|
columnsMp[col] = &ref
|
||||||
|
refs = append(refs, &ref)
|
||||||
|
}
|
||||||
|
|
||||||
if err := rows.Scan(refs...); err != nil {
|
if err := rows.Scan(refs...); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
|
if cnt == 0 && !sInd.IsNil() {
|
||||||
|
sInd.Set(reflect.New(sInd.Type()).Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
var ind reflect.Value
|
||||||
|
if eTyps[0].Kind() == reflect.Ptr {
|
||||||
|
ind = reflect.New(eTyps[0].Elem())
|
||||||
|
} else {
|
||||||
|
ind = reflect.New(eTyps[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if ind.Kind() == reflect.Ptr {
|
||||||
|
ind = ind.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if sMi != nil {
|
||||||
|
for _, col := range columns {
|
||||||
|
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||||
|
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||||
|
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < ind.NumField(); i++ {
|
||||||
|
f := ind.Field(i)
|
||||||
|
fe := ind.Type().Field(i)
|
||||||
|
|
||||||
|
var attrs map[string]bool
|
||||||
|
var tags map[string]string
|
||||||
|
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||||
|
var col string
|
||||||
|
if col = tags["column"]; len(col) == 0 {
|
||||||
|
col = snakeString(fe.Name)
|
||||||
|
}
|
||||||
|
if v, ok := columnsMp[col]; ok {
|
||||||
|
value := reflect.ValueOf(v).Elem().Interface()
|
||||||
|
o.setFieldValue(f, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eTyps[0].Kind() == reflect.Ptr {
|
||||||
|
ind = ind.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
sInd = reflect.Append(sInd, ind)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if err := rows.Scan(refs...); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
|
||||||
|
}
|
||||||
|
|
||||||
cnt++
|
cnt++
|
||||||
}
|
}
|
||||||
|
|
||||||
if cnt > 0 {
|
if cnt > 0 {
|
||||||
|
|
||||||
|
if structMode {
|
||||||
|
sInds[0].Set(sInd)
|
||||||
|
} else {
|
||||||
for i, sInd := range sInds {
|
for i, sInd := range sInds {
|
||||||
nInd := nInds[i]
|
nInd := nInds[i]
|
||||||
sInd.Set(nInd)
|
sInd.Set(nInd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) readValues(container interface{}) (int64, error) {
|
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
|
||||||
var (
|
var (
|
||||||
maps []Params
|
maps []Params
|
||||||
lists []ParamsList
|
lists []ParamsList
|
||||||
@ -455,21 +549,41 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
|||||||
rs = r
|
rs = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
refs []interface{}
|
refs []interface{}
|
||||||
cnt int64
|
cnt int64
|
||||||
cols []string
|
cols []string
|
||||||
|
indexs []int
|
||||||
)
|
)
|
||||||
|
|
||||||
for rs.Next() {
|
for rs.Next() {
|
||||||
if cnt == 0 {
|
if cnt == 0 {
|
||||||
if columns, err := rs.Columns(); err != nil {
|
if columns, err := rs.Columns(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
} else {
|
} else {
|
||||||
|
if len(needCols) > 0 {
|
||||||
|
indexs = make([]int, 0, len(needCols))
|
||||||
|
} else {
|
||||||
|
indexs = make([]int, 0, len(columns))
|
||||||
|
}
|
||||||
|
|
||||||
cols = columns
|
cols = columns
|
||||||
refs = make([]interface{}, len(cols))
|
refs = make([]interface{}, len(cols))
|
||||||
for i, _ := range refs {
|
for i, _ := range refs {
|
||||||
var ref sql.NullString
|
var ref sql.NullString
|
||||||
refs[i] = &ref
|
refs[i] = &ref
|
||||||
|
|
||||||
|
if len(needCols) > 0 {
|
||||||
|
for _, c := range needCols {
|
||||||
|
if c == cols[i] {
|
||||||
|
indexs = append(indexs, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
indexs = append(indexs, i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -481,7 +595,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
|||||||
switch typ {
|
switch typ {
|
||||||
case 1:
|
case 1:
|
||||||
params := make(Params, len(cols))
|
params := make(Params, len(cols))
|
||||||
for i, ref := range refs {
|
for _, i := range indexs {
|
||||||
|
ref := refs[i]
|
||||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||||
if value.Valid {
|
if value.Valid {
|
||||||
params[cols[i]] = value.String
|
params[cols[i]] = value.String
|
||||||
@ -492,7 +607,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
|||||||
maps = append(maps, params)
|
maps = append(maps, params)
|
||||||
case 2:
|
case 2:
|
||||||
params := make(ParamsList, 0, len(cols))
|
params := make(ParamsList, 0, len(cols))
|
||||||
for _, ref := range refs {
|
for _, i := range indexs {
|
||||||
|
ref := refs[i]
|
||||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||||
if value.Valid {
|
if value.Valid {
|
||||||
params = append(params, value.String)
|
params = append(params, value.String)
|
||||||
@ -502,7 +618,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
lists = append(lists, params)
|
lists = append(lists, params)
|
||||||
case 3:
|
case 3:
|
||||||
for _, ref := range refs {
|
for _, i := range indexs {
|
||||||
|
ref := refs[i]
|
||||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||||
if value.Valid {
|
if value.Valid {
|
||||||
list = append(list, value.String)
|
list = append(list, value.String)
|
||||||
@ -527,18 +644,166 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
|||||||
return cnt, nil
|
return cnt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) Values(container *[]Params) (int64, error) {
|
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
|
||||||
return o.readValues(container)
|
var (
|
||||||
|
maps Params
|
||||||
|
ind *reflect.Value
|
||||||
|
)
|
||||||
|
|
||||||
|
typ := 0
|
||||||
|
switch container.(type) {
|
||||||
|
case *Params:
|
||||||
|
typ = 1
|
||||||
|
default:
|
||||||
|
typ = 2
|
||||||
|
vl := reflect.ValueOf(container)
|
||||||
|
id := reflect.Indirect(vl)
|
||||||
|
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
|
||||||
|
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
|
ind = &id
|
||||||
return o.readValues(container)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
|
query := o.query
|
||||||
return o.readValues(container)
|
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||||
|
|
||||||
|
var rs *sql.Rows
|
||||||
|
if r, err := o.orm.db.Query(query, args...); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
rs = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
refs []interface{}
|
||||||
|
cnt int64
|
||||||
|
cols []string
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
keyIndex = -1
|
||||||
|
valueIndex = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
for rs.Next() {
|
||||||
|
if cnt == 0 {
|
||||||
|
if columns, err := rs.Columns(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
cols = columns
|
||||||
|
refs = make([]interface{}, len(cols))
|
||||||
|
for i, _ := range refs {
|
||||||
|
if keyCol == cols[i] {
|
||||||
|
keyIndex = i
|
||||||
|
}
|
||||||
|
|
||||||
|
if typ == 1 || keyIndex == i {
|
||||||
|
var ref sql.NullString
|
||||||
|
refs[i] = &ref
|
||||||
|
} else {
|
||||||
|
var ref interface{}
|
||||||
|
refs[i] = &ref
|
||||||
|
}
|
||||||
|
|
||||||
|
if valueCol == cols[i] {
|
||||||
|
valueIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if keyIndex == -1 || valueIndex == -1 {
|
||||||
|
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rs.Scan(refs...); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cnt == 0 {
|
||||||
|
switch typ {
|
||||||
|
case 1:
|
||||||
|
maps = make(Params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case 1:
|
||||||
|
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
|
||||||
|
if value.Valid {
|
||||||
|
maps[key] = value.String
|
||||||
|
} else {
|
||||||
|
maps[key] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if id := ind.FieldByName(camelString(key)); id.IsValid() {
|
||||||
|
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt++
|
||||||
|
}
|
||||||
|
|
||||||
|
if typ == 1 {
|
||||||
|
v, _ := container.(*Params)
|
||||||
|
*v = maps
|
||||||
|
}
|
||||||
|
|
||||||
|
return cnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// query data to []map[string]interface
|
||||||
|
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
|
||||||
|
return o.readValues(container, cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query data to [][]interface
|
||||||
|
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
|
||||||
|
return o.readValues(container, cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query data to []interface
|
||||||
|
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
|
||||||
|
return o.readValues(container, cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query all rows into map[string]interface with specify key and value column name.
|
||||||
|
// keyCol = "name", valueCol = "value"
|
||||||
|
// table data
|
||||||
|
// name | value
|
||||||
|
// total | 100
|
||||||
|
// found | 200
|
||||||
|
// to map[string]interface{}{
|
||||||
|
// "total": 100,
|
||||||
|
// "found": 200,
|
||||||
|
// }
|
||||||
|
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||||
|
return o.queryRowsTo(result, keyCol, valueCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query all rows into struct with specify key and value column name.
|
||||||
|
// keyCol = "name", valueCol = "value"
|
||||||
|
// table data
|
||||||
|
// name | value
|
||||||
|
// total | 100
|
||||||
|
// found | 200
|
||||||
|
// to struct {
|
||||||
|
// Total int
|
||||||
|
// Found int
|
||||||
|
// }
|
||||||
|
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||||
|
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// return prepared raw statement for used in times.
|
||||||
func (o *rawSet) Prepare() (RawPreparer, error) {
|
func (o *rawSet) Prepare() (RawPreparer, error) {
|
||||||
return newRawPreparer(o)
|
return newRawPreparer(o)
|
||||||
}
|
}
|
||||||
|
242
orm/orm_test.go
242
orm/orm_test.go
@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tmp struct {
|
|
||||||
Skip0 string
|
|
||||||
Id int
|
|
||||||
Char *string
|
|
||||||
Skip1 int `orm:"-"`
|
|
||||||
Date time.Time
|
|
||||||
DateTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
Boolean = false
|
|
||||||
Text = ""
|
|
||||||
Int64 = 0
|
|
||||||
Uint = 0
|
|
||||||
|
|
||||||
tmp := new(Tmp)
|
|
||||||
|
|
||||||
cols = []string{
|
|
||||||
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
|
|
||||||
}
|
|
||||||
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
|
|
||||||
values = []interface{}{
|
|
||||||
tmp, &Boolean, &Text, &Int64, &Uint,
|
|
||||||
}
|
|
||||||
err = dORM.Raw(query, 1).QueryRow(values...)
|
|
||||||
throwFailNow(t, err)
|
|
||||||
|
|
||||||
for _, col := range cols {
|
|
||||||
switch col {
|
|
||||||
case "id":
|
|
||||||
throwFail(t, AssertIs(tmp.Id, data_values[col]))
|
|
||||||
case "char":
|
|
||||||
c := tmp.Char
|
|
||||||
throwFail(t, AssertIs(*c, data_values[col]))
|
|
||||||
case "date":
|
|
||||||
v := tmp.Date.In(DefaultTimeLoc)
|
|
||||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
|
||||||
throwFail(t, AssertIs(v, value, test_Date))
|
|
||||||
case "datetime":
|
|
||||||
v := tmp.DateTime.In(DefaultTimeLoc)
|
|
||||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
|
||||||
throwFail(t, AssertIs(v, value, test_DateTime))
|
|
||||||
case "boolean":
|
|
||||||
throwFail(t, AssertIs(Boolean, data_values[col]))
|
|
||||||
case "text":
|
|
||||||
throwFail(t, AssertIs(Text, data_values[col]))
|
|
||||||
case "int64":
|
|
||||||
throwFail(t, AssertIs(Int64, data_values[col]))
|
|
||||||
case "uint":
|
|
||||||
throwFail(t, AssertIs(Uint, data_values[col]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
uid int
|
uid int
|
||||||
status *int
|
status *int
|
||||||
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
|
|||||||
func TestQueryRows(t *testing.T) {
|
func TestQueryRows(t *testing.T) {
|
||||||
Q := dDbBaser.TableQuote()
|
Q := dDbBaser.TableQuote()
|
||||||
|
|
||||||
cols := []string{
|
|
||||||
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
|
|
||||||
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
|
|
||||||
}
|
|
||||||
|
|
||||||
var datas []*Data
|
var datas []*Data
|
||||||
var dids []int
|
|
||||||
|
|
||||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||||
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
|
num, err := dORM.Raw(query).QueryRows(&datas)
|
||||||
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
|
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
throwFailNow(t, AssertIs(num, 1))
|
throwFailNow(t, AssertIs(num, 1))
|
||||||
throwFailNow(t, AssertIs(len(datas), 1))
|
throwFailNow(t, AssertIs(len(datas), 1))
|
||||||
throwFailNow(t, AssertIs(len(dids), 1))
|
|
||||||
throwFailNow(t, AssertIs(dids[0], 1))
|
|
||||||
|
|
||||||
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
|
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
|
||||||
|
|
||||||
@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tmp struct {
|
var datas2 []Data
|
||||||
Id int
|
|
||||||
Name string
|
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||||
Skiped0 string `orm:"-"`
|
num, err = dORM.Raw(query).QueryRows(&datas2)
|
||||||
Pid *int
|
throwFailNow(t, err)
|
||||||
Skiped1 Data
|
throwFailNow(t, AssertIs(num, 1))
|
||||||
Skiped2 *Data
|
throwFailNow(t, AssertIs(len(datas2), 1))
|
||||||
|
|
||||||
|
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
|
||||||
|
|
||||||
|
for name, value := range Data_Values {
|
||||||
|
e := ind.FieldByName(name)
|
||||||
|
vu := e.Interface()
|
||||||
|
switch name {
|
||||||
|
case "Date":
|
||||||
|
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||||
|
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||||
|
case "DateTime":
|
||||||
|
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||||
|
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||||
|
}
|
||||||
|
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var ids []int
|
||||||
ids []int
|
var usernames []string
|
||||||
userNames []string
|
query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
|
||||||
profileIds1 []int
|
num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
|
||||||
profileIds2 []*int
|
|
||||||
createds []time.Time
|
|
||||||
updateds []time.Time
|
|
||||||
tmps1 []*Tmp
|
|
||||||
tmps2 []Tmp
|
|
||||||
)
|
|
||||||
cols = []string{
|
|
||||||
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
|
|
||||||
}
|
|
||||||
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
|
|
||||||
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
|
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
throwFailNow(t, AssertIs(num, 3))
|
throwFailNow(t, AssertIs(num, 3))
|
||||||
|
throwFailNow(t, AssertIs(len(ids), 3))
|
||||||
var users []User
|
throwFailNow(t, AssertIs(ids[0], 2))
|
||||||
dORM.QueryTable("user").OrderBy("Id").All(&users)
|
throwFailNow(t, AssertIs(usernames[0], "slene"))
|
||||||
|
throwFailNow(t, AssertIs(ids[1], 3))
|
||||||
for i := 0; i < 3; i++ {
|
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
|
||||||
id := ids[i]
|
throwFailNow(t, AssertIs(ids[2], 4))
|
||||||
name := userNames[i]
|
throwFailNow(t, AssertIs(usernames[2], "nobody"))
|
||||||
pid1 := profileIds1[i]
|
|
||||||
pid2 := profileIds2[i]
|
|
||||||
created := createds[i]
|
|
||||||
updated := updateds[i]
|
|
||||||
|
|
||||||
user := users[i]
|
|
||||||
throwFailNow(t, AssertIs(id, user.Id))
|
|
||||||
throwFailNow(t, AssertIs(name, user.UserName))
|
|
||||||
if user.Profile != nil {
|
|
||||||
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
|
|
||||||
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
|
|
||||||
} else {
|
|
||||||
throwFailNow(t, AssertIs(pid1, 0))
|
|
||||||
throwFailNow(t, AssertIs(pid2, nil))
|
|
||||||
}
|
|
||||||
throwFailNow(t, AssertIs(created, user.Created, test_Date))
|
|
||||||
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
|
|
||||||
|
|
||||||
tmp := tmps1[i]
|
|
||||||
tmp1 := *tmp
|
|
||||||
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
|
|
||||||
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
|
|
||||||
if user.Profile != nil {
|
|
||||||
pid := tmp1.Pid
|
|
||||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
|
||||||
} else {
|
|
||||||
throwFailNow(t, AssertIs(tmp1.Pid, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp2 := tmps2[i]
|
|
||||||
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
|
|
||||||
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
|
|
||||||
if user.Profile != nil {
|
|
||||||
pid := tmp2.Pid
|
|
||||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
|
||||||
} else {
|
|
||||||
throwFailNow(t, AssertIs(tmp2.Pid, nil))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Sec struct {
|
|
||||||
Id int
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
var tmp []*Sec
|
|
||||||
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
|
|
||||||
num, err = dORM.Raw(query).QueryRows(&tmp)
|
|
||||||
throwFail(t, err)
|
|
||||||
throwFail(t, AssertIs(num, 1))
|
|
||||||
throwFail(t, AssertIs(tmp[0], nil))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRawValues(t *testing.T) {
|
func TestRawValues(t *testing.T) {
|
||||||
@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
|
|||||||
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
|
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
qs = dORM.QueryTable("comment")
|
||||||
|
num, err = qs.Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 6))
|
||||||
|
|
||||||
|
qs = dORM.QueryTable("post")
|
||||||
|
num, err = qs.Filter("Id", 3).Delete()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
qs = dORM.QueryTable("comment")
|
||||||
|
num, err = qs.Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 4))
|
||||||
|
|
||||||
|
fmt.Println("...")
|
||||||
|
qs = dORM.QueryTable("comment")
|
||||||
|
num, err = qs.Filter("Post__User", 3).Delete()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
|
||||||
|
qs = dORM.QueryTable("comment")
|
||||||
|
num, err = qs.Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransaction(t *testing.T) {
|
func TestTransaction(t *testing.T) {
|
||||||
@ -1724,3 +1642,41 @@ func TestTransaction(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadOrCreate(t *testing.T) {
|
||||||
|
u := &User{
|
||||||
|
UserName: "Kyle",
|
||||||
|
Email: "kylemcc@gmail.com",
|
||||||
|
Password: "other_pass",
|
||||||
|
Status: 7,
|
||||||
|
IsStaff: false,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
created, pk, err := dORM.ReadOrCreate(u, "UserName")
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(created, true))
|
||||||
|
throwFail(t, AssertIs(u.UserName, "Kyle"))
|
||||||
|
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
|
||||||
|
throwFail(t, AssertIs(u.Password, "other_pass"))
|
||||||
|
throwFail(t, AssertIs(u.Status, 7))
|
||||||
|
throwFail(t, AssertIs(u.IsStaff, false))
|
||||||
|
throwFail(t, AssertIs(u.IsActive, true))
|
||||||
|
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
|
||||||
|
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
|
||||||
|
|
||||||
|
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
|
||||||
|
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(created, false))
|
||||||
|
throwFail(t, AssertIs(nu.Id, u.Id))
|
||||||
|
throwFail(t, AssertIs(pk, u.Id))
|
||||||
|
throwFail(t, AssertIs(nu.UserName, u.UserName))
|
||||||
|
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
|
||||||
|
throwFail(t, AssertIs(nu.Password, u.Password))
|
||||||
|
throwFail(t, AssertIs(nu.Status, u.Status))
|
||||||
|
throwFail(t, AssertIs(nu.IsStaff, u.IsStaff))
|
||||||
|
throwFail(t, AssertIs(nu.IsActive, u.IsActive))
|
||||||
|
|
||||||
|
dORM.Delete(u)
|
||||||
|
}
|
||||||
|
38
orm/types.go
38
orm/types.go
@ -6,11 +6,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// database driver
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Name() string
|
Name() string
|
||||||
Type() DriverType
|
Type() DriverType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// field info
|
||||||
type Fielder interface {
|
type Fielder interface {
|
||||||
String() string
|
String() string
|
||||||
FieldType() int
|
FieldType() int
|
||||||
@ -18,9 +20,12 @@ type Fielder interface {
|
|||||||
RawValue() interface{}
|
RawValue() interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// orm struct
|
||||||
type Ormer interface {
|
type Ormer interface {
|
||||||
Read(interface{}, ...string) error
|
Read(interface{}, ...string) error
|
||||||
|
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
|
InsertMulti(int, interface{}) (int64, error)
|
||||||
Update(interface{}, ...string) (int64, error)
|
Update(interface{}, ...string) (int64, error)
|
||||||
Delete(interface{}) (int64, error)
|
Delete(interface{}) (int64, error)
|
||||||
LoadRelated(interface{}, string, ...interface{}) (int64, error)
|
LoadRelated(interface{}, string, ...interface{}) (int64, error)
|
||||||
@ -32,13 +37,16 @@ type Ormer interface {
|
|||||||
Rollback() error
|
Rollback() error
|
||||||
Raw(string, ...interface{}) RawSeter
|
Raw(string, ...interface{}) RawSeter
|
||||||
Driver() Driver
|
Driver() Driver
|
||||||
|
GetDB() dbQuerier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insert prepared statement
|
||||||
type Inserter interface {
|
type Inserter interface {
|
||||||
Insert(interface{}) (int64, error)
|
Insert(interface{}) (int64, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// query seter
|
||||||
type QuerySeter interface {
|
type QuerySeter interface {
|
||||||
Filter(string, ...interface{}) QuerySeter
|
Filter(string, ...interface{}) QuerySeter
|
||||||
Exclude(string, ...interface{}) QuerySeter
|
Exclude(string, ...interface{}) QuerySeter
|
||||||
@ -57,8 +65,11 @@ type QuerySeter interface {
|
|||||||
Values(*[]Params, ...string) (int64, error)
|
Values(*[]Params, ...string) (int64, error)
|
||||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||||
ValuesFlat(*ParamsList, string) (int64, error)
|
ValuesFlat(*ParamsList, string) (int64, error)
|
||||||
|
RowsToMap(*Params, string, string) (int64, error)
|
||||||
|
RowsToStruct(interface{}, string, string) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// model to model query struct
|
||||||
type QueryM2Mer interface {
|
type QueryM2Mer interface {
|
||||||
Add(...interface{}) (int64, error)
|
Add(...interface{}) (int64, error)
|
||||||
Remove(...interface{}) (int64, error)
|
Remove(...interface{}) (int64, error)
|
||||||
@ -67,22 +78,27 @@ type QueryM2Mer interface {
|
|||||||
Count() (int64, error)
|
Count() (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// raw query statement
|
||||||
type RawPreparer interface {
|
type RawPreparer interface {
|
||||||
Exec(...interface{}) (sql.Result, error)
|
Exec(...interface{}) (sql.Result, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// raw query seter
|
||||||
type RawSeter interface {
|
type RawSeter interface {
|
||||||
Exec() (sql.Result, error)
|
Exec() (sql.Result, error)
|
||||||
QueryRow(...interface{}) error
|
QueryRow(...interface{}) error
|
||||||
QueryRows(...interface{}) (int64, error)
|
QueryRows(...interface{}) (int64, error)
|
||||||
SetArgs(...interface{}) RawSeter
|
SetArgs(...interface{}) RawSeter
|
||||||
Values(*[]Params) (int64, error)
|
Values(*[]Params, ...string) (int64, error)
|
||||||
ValuesList(*[]ParamsList) (int64, error)
|
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||||
ValuesFlat(*ParamsList) (int64, error)
|
ValuesFlat(*ParamsList, ...string) (int64, error)
|
||||||
|
RowsToMap(*Params, string, string) (int64, error)
|
||||||
|
RowsToStruct(interface{}, string, string) (int64, error)
|
||||||
Prepare() (RawPreparer, error)
|
Prepare() (RawPreparer, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// statement querier
|
||||||
type stmtQuerier interface {
|
type stmtQuerier interface {
|
||||||
Close() error
|
Close() error
|
||||||
Exec(args ...interface{}) (sql.Result, error)
|
Exec(args ...interface{}) (sql.Result, error)
|
||||||
@ -90,6 +106,7 @@ type stmtQuerier interface {
|
|||||||
QueryRow(args ...interface{}) *sql.Row
|
QueryRow(args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// db querier
|
||||||
type dbQuerier interface {
|
type dbQuerier interface {
|
||||||
Prepare(query string) (*sql.Stmt, error)
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
@ -97,19 +114,31 @@ type dbQuerier interface {
|
|||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type DB interface {
|
||||||
|
// Begin() (*sql.Tx, error)
|
||||||
|
// Prepare(query string) (stmtQuerier, error)
|
||||||
|
// Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
// Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
// QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
|
// }
|
||||||
|
|
||||||
|
// transaction beginner
|
||||||
type txer interface {
|
type txer interface {
|
||||||
Begin() (*sql.Tx, error)
|
Begin() (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transaction ending
|
||||||
type txEnder interface {
|
type txEnder interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// base database struct
|
||||||
type dbBaser interface {
|
type dbBaser interface {
|
||||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
|
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||||
|
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||||
@ -123,6 +152,7 @@ type dbBaser interface {
|
|||||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||||
|
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
|
||||||
MaxLimit() uint64
|
MaxLimit() uint64
|
||||||
TableQuote() string
|
TableQuote() string
|
||||||
ReplaceMarks(*string)
|
ReplaceMarks(*string)
|
||||||
|
27
orm/utils.go
27
orm/utils.go
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
type StrTo string
|
type StrTo string
|
||||||
|
|
||||||
|
// set string
|
||||||
func (f *StrTo) Set(v string) {
|
func (f *StrTo) Set(v string) {
|
||||||
if v != "" {
|
if v != "" {
|
||||||
*f = StrTo(v)
|
*f = StrTo(v)
|
||||||
@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean string
|
||||||
func (f *StrTo) Clear() {
|
func (f *StrTo) Clear() {
|
||||||
*f = StrTo(0x1E)
|
*f = StrTo(0x1E)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check string exist
|
||||||
func (f StrTo) Exist() bool {
|
func (f StrTo) Exist() bool {
|
||||||
return string(f) != string(0x1E)
|
return string(f) != string(0x1E)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to bool
|
||||||
func (f StrTo) Bool() (bool, error) {
|
func (f StrTo) Bool() (bool, error) {
|
||||||
return strconv.ParseBool(f.String())
|
return strconv.ParseBool(f.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to float32
|
||||||
func (f StrTo) Float32() (float32, error) {
|
func (f StrTo) Float32() (float32, error) {
|
||||||
v, err := strconv.ParseFloat(f.String(), 32)
|
v, err := strconv.ParseFloat(f.String(), 32)
|
||||||
return float32(v), err
|
return float32(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to float64
|
||||||
func (f StrTo) Float64() (float64, error) {
|
func (f StrTo) Float64() (float64, error) {
|
||||||
return strconv.ParseFloat(f.String(), 64)
|
return strconv.ParseFloat(f.String(), 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to int
|
||||||
func (f StrTo) Int() (int, error) {
|
func (f StrTo) Int() (int, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||||
return int(v), err
|
return int(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to int8
|
||||||
func (f StrTo) Int8() (int8, error) {
|
func (f StrTo) Int8() (int8, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||||
return int8(v), err
|
return int8(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to int16
|
||||||
func (f StrTo) Int16() (int16, error) {
|
func (f StrTo) Int16() (int16, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 16)
|
v, err := strconv.ParseInt(f.String(), 10, 16)
|
||||||
return int16(v), err
|
return int16(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to int32
|
||||||
func (f StrTo) Int32() (int32, error) {
|
func (f StrTo) Int32() (int32, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||||
return int32(v), err
|
return int32(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to int64
|
||||||
func (f StrTo) Int64() (int64, error) {
|
func (f StrTo) Int64() (int64, error) {
|
||||||
v, err := strconv.ParseInt(f.String(), 10, 64)
|
v, err := strconv.ParseInt(f.String(), 10, 64)
|
||||||
return int64(v), err
|
return int64(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to uint
|
||||||
func (f StrTo) Uint() (uint, error) {
|
func (f StrTo) Uint() (uint, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||||
return uint(v), err
|
return uint(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to uint8
|
||||||
func (f StrTo) Uint8() (uint8, error) {
|
func (f StrTo) Uint8() (uint8, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||||
return uint8(v), err
|
return uint8(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to uint16
|
||||||
func (f StrTo) Uint16() (uint16, error) {
|
func (f StrTo) Uint16() (uint16, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 16)
|
v, err := strconv.ParseUint(f.String(), 10, 16)
|
||||||
return uint16(v), err
|
return uint16(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to uint31
|
||||||
func (f StrTo) Uint32() (uint32, error) {
|
func (f StrTo) Uint32() (uint32, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||||
return uint32(v), err
|
return uint32(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to uint64
|
||||||
func (f StrTo) Uint64() (uint64, error) {
|
func (f StrTo) Uint64() (uint64, error) {
|
||||||
v, err := strconv.ParseUint(f.String(), 10, 64)
|
v, err := strconv.ParseUint(f.String(), 10, 64)
|
||||||
return uint64(v), err
|
return uint64(v), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// string to string
|
||||||
func (f StrTo) String() string {
|
func (f StrTo) String() string {
|
||||||
if f.Exist() {
|
if f.Exist() {
|
||||||
return string(f)
|
return string(f)
|
||||||
@ -96,6 +113,7 @@ func (f StrTo) String() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// interface to string
|
||||||
func ToStr(value interface{}, args ...int) (s string) {
|
func ToStr(value interface{}, args ...int) (s string) {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case bool:
|
case bool:
|
||||||
@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// interface to int64
|
||||||
func ToInt64(value interface{}) (d int64) {
|
func ToInt64(value interface{}) (d int64) {
|
||||||
val := reflect.ValueOf(value)
|
val := reflect.ValueOf(value)
|
||||||
switch value.(type) {
|
switch value.(type) {
|
||||||
@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// snake string, XxYy to xx_yy
|
||||||
func snakeString(s string) string {
|
func snakeString(s string) string {
|
||||||
data := make([]byte, 0, len(s)*2)
|
data := make([]byte, 0, len(s)*2)
|
||||||
j := false
|
j := false
|
||||||
@ -164,6 +184,7 @@ func snakeString(s string) string {
|
|||||||
return strings.ToLower(string(data[:len(data)]))
|
return strings.ToLower(string(data[:len(data)]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// camel string, xx_yy to XxYy
|
||||||
func camelString(s string) string {
|
func camelString(s string) string {
|
||||||
data := make([]byte, 0, len(s))
|
data := make([]byte, 0, len(s))
|
||||||
j := false
|
j := false
|
||||||
@ -190,6 +211,7 @@ func camelString(s string) string {
|
|||||||
|
|
||||||
type argString []string
|
type argString []string
|
||||||
|
|
||||||
|
// get string by index from string slice
|
||||||
func (a argString) Get(i int, args ...string) (r string) {
|
func (a argString) Get(i int, args ...string) (r string) {
|
||||||
if i >= 0 && i < len(a) {
|
if i >= 0 && i < len(a) {
|
||||||
r = a[i]
|
r = a[i]
|
||||||
@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
|
|||||||
|
|
||||||
type argInt []int
|
type argInt []int
|
||||||
|
|
||||||
|
// get int by index from int slice
|
||||||
func (a argInt) Get(i int, args ...int) (r int) {
|
func (a argInt) Get(i int, args ...int) (r int) {
|
||||||
if i >= 0 && i < len(a) {
|
if i >= 0 && i < len(a) {
|
||||||
r = a[i]
|
r = a[i]
|
||||||
@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
|
|||||||
|
|
||||||
type argAny []interface{}
|
type argAny []interface{}
|
||||||
|
|
||||||
|
// get interface by index from interface slice
|
||||||
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
||||||
if i >= 0 && i < len(a) {
|
if i >= 0 && i < len(a) {
|
||||||
r = a[i]
|
r = a[i]
|
||||||
@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse time to string with location
|
||||||
func timeParse(dateString, format string) (time.Time, error) {
|
func timeParse(dateString, format string) (time.Time, error) {
|
||||||
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
|
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
|
||||||
return tp, err
|
return tp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// format time string
|
||||||
func timeFormat(t time.Time, format string) string {
|
func timeFormat(t time.Time, format string) string {
|
||||||
return t.Format(format)
|
return t.Format(format)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get pointer indirect type
|
||||||
func indirectType(v reflect.Type) reflect.Type {
|
func indirectType(v reflect.Type) reflect.Type {
|
||||||
switch v.Kind() {
|
switch v.Kind() {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
|
75
plugins/auth/basic.go
Normal file
75
plugins/auth/basic.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// basic auth for plugin
|
||||||
|
package auth
|
||||||
|
|
||||||
|
// Example:
|
||||||
|
// func SecretAuth(username, password string) bool {
|
||||||
|
// if username == "astaxie" && password == "helloBeego" {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// authPlugin := auth.NewBasicAuthenticator(SecretAuth)
|
||||||
|
// beego.AddFilter("*","AfterStatic",authPlugin)
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
a := &BasicAuth{Secrets: secrets, Realm: Realm}
|
||||||
|
if username := a.CheckAuth(ctx.Request); username == "" {
|
||||||
|
a.RequireAuth(ctx.ResponseWriter, ctx.Request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecretProvider func(user, pass string) bool
|
||||||
|
|
||||||
|
type BasicAuth struct {
|
||||||
|
Secrets SecretProvider
|
||||||
|
Realm string
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Checks the username/password combination from the request. Returns
|
||||||
|
either an empty string (authentication failed) or the name of the
|
||||||
|
authenticated user.
|
||||||
|
|
||||||
|
Supports MD5 and SHA1 password entries
|
||||||
|
*/
|
||||||
|
func (a *BasicAuth) CheckAuth(r *http.Request) string {
|
||||||
|
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
||||||
|
if len(s) != 2 || s[0] != "Basic" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
pair := strings.SplitN(string(b), ":", 2)
|
||||||
|
if len(pair) != 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Secrets(pair[0], pair[1]) {
|
||||||
|
return pair[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
http.Handler for BasicAuth which initiates the authentication process
|
||||||
|
(or requires reauthentication).
|
||||||
|
*/
|
||||||
|
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte("401 Unauthorized\n"))
|
||||||
|
}
|
79
router.go
79
router.go
@ -1,7 +1,10 @@
|
|||||||
package beego
|
package beego
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@ -30,6 +33,14 @@ const (
|
|||||||
var (
|
var (
|
||||||
// supported http methods.
|
// supported http methods.
|
||||||
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
|
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
|
||||||
|
// these beego.Controller's methods shouldn't reflect to AutoRouter
|
||||||
|
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
|
||||||
|
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
|
||||||
|
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
|
||||||
|
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
|
||||||
|
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
|
||||||
|
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
|
||||||
|
"GetControllerAndAction"}
|
||||||
)
|
)
|
||||||
|
|
||||||
type controllerInfo struct {
|
type controllerInfo struct {
|
||||||
@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
|||||||
params := make(map[int]string)
|
params := make(map[int]string)
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
if strings.HasPrefix(part, ":") {
|
if strings.HasPrefix(part, ":") {
|
||||||
expr := "(.+)"
|
expr := "(.*)"
|
||||||
//a user may choose to override the defult expression
|
//a user may choose to override the defult expression
|
||||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||||
if index := strings.Index(part, "("); index != -1 {
|
if index := strings.Index(part, "("); index != -1 {
|
||||||
@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
|||||||
j++
|
j++
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(part, "*") {
|
if strings.HasPrefix(part, "*") {
|
||||||
expr := "(.+)"
|
expr := "(.*)"
|
||||||
if part == "*.*" {
|
if part == "*.*" {
|
||||||
params[j] = ":path"
|
params[j] = ":path"
|
||||||
parts[i] = "([^.]+).([^.]+)"
|
parts[i] = "([^.]+).([^.]+)"
|
||||||
@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
|||||||
// Add auto router to ControllerRegistor.
|
// Add auto router to ControllerRegistor.
|
||||||
// example beego.AddAuto(&MainContorlller{}),
|
// example beego.AddAuto(&MainContorlller{}),
|
||||||
// MainController has method List and Page.
|
// MainController has method List and Page.
|
||||||
// visit the url /main/list to exec List function
|
// visit the url /main/list to execute List function
|
||||||
// /main/page to exec Page function.
|
// /main/page to execute Page function.
|
||||||
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||||||
p.enableAuto = true
|
p.enableAuto = true
|
||||||
reflectVal := reflect.ValueOf(c)
|
reflectVal := reflect.ValueOf(c)
|
||||||
@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
|||||||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||||
}
|
}
|
||||||
for i := 0; i < rt.NumMethod(); i++ {
|
for i := 0; i < rt.NumMethod(); i++ {
|
||||||
|
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add auto router to ControllerRegistor with prefix.
|
||||||
|
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
|
||||||
|
// MainController has method List and Page.
|
||||||
|
// visit the url /admin/main/list to execute List function
|
||||||
|
// /admin/main/page to execute Page function.
|
||||||
|
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
|
||||||
|
p.enableAuto = true
|
||||||
|
reflectVal := reflect.ValueOf(c)
|
||||||
|
rt := reflectVal.Type()
|
||||||
|
ct := reflect.Indirect(reflectVal).Type()
|
||||||
|
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
|
||||||
|
if _, ok := p.autoRouter[firstParam]; ok {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||||
|
}
|
||||||
|
for i := 0; i < rt.NumMethod(); i++ {
|
||||||
|
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||||
|
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// [Deprecated] use InsertFilter.
|
// [Deprecated] use InsertFilter.
|
||||||
// Add FilterFunc with pattern for action.
|
// Add FilterFunc with pattern for action.
|
||||||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
|
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
|
||||||
mr := buildFilter(pattern, filter)
|
mr, err := buildFilter(pattern, filter)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
switch action {
|
switch action {
|
||||||
case "BeforeRouter":
|
case "BeforeRouter":
|
||||||
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
|
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
|
||||||
@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
|
|||||||
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
|
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
|
||||||
}
|
}
|
||||||
p.enableFilter = true
|
p.enableFilter = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a FilterFunc with pattern rule and action constant.
|
// Add a FilterFunc with pattern rule and action constant.
|
||||||
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) {
|
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
|
||||||
mr := buildFilter(pattern, filter)
|
mr, err := buildFilter(pattern, filter)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
p.filters[pos] = append(p.filters[pos], mr)
|
p.filters[pos] = append(p.filters[pos], mr)
|
||||||
p.enableFilter = true
|
p.enableFilter = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UrlFor does another controller handler in this request function.
|
// UrlFor does another controller handler in this request function.
|
||||||
@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
// session init
|
// session init
|
||||||
if SessionOn {
|
if SessionOn {
|
||||||
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
|
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
|
||||||
defer context.Input.CruSession.SessionRelease()
|
defer func() {
|
||||||
|
context.Input.CruSession.SessionRelease(w)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
|
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
|
||||||
@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
// pattern /admin url /admin 200 /admin/ 200
|
// pattern /admin url /admin 200 /admin/ 200
|
||||||
// pattern /admin/ url /admin 301 /admin/ 200
|
// pattern /admin/ url /admin 301 /admin/ 200
|
||||||
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
|
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
|
||||||
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
|
|
||||||
http.Redirect(w, r, requestPath+"/", 301)
|
http.Redirect(w, r, requestPath+"/", 301)
|
||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern {
|
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
|
||||||
runMethod = p.getRunMethod(r.Method, context, route)
|
runMethod = p.getRunMethod(r.Method, context, route)
|
||||||
if runMethod != "" {
|
if runMethod != "" {
|
||||||
runrouter = route.controllerType
|
runrouter = route.controllerType
|
||||||
@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
|
|||||||
w.started = true
|
w.started = true
|
||||||
w.writer.WriteHeader(code)
|
w.writer.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hijacker for http
|
||||||
|
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hj, ok := w.writer.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
println("supported?")
|
||||||
|
return nil, nil, errors.New("webserver doesn't support hijacking")
|
||||||
|
}
|
||||||
|
return hj.Hijack()
|
||||||
|
}
|
||||||
|
@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
|
|||||||
t.Errorf(w.Body.String() + "user define func can't run")
|
t.Errorf(w.Body.String() + "user define func can't run")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAutoPrefix(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler := NewControllerRegistor()
|
||||||
|
handler.AddAutoPrefix("/admin", &TestController{})
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if w.Body.String() != "i am list" {
|
||||||
|
t.Errorf("TestAutoPrefix can't run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -28,21 +28,21 @@ Then in you web app init the global session manager
|
|||||||
* Use **memory** as provider:
|
* Use **memory** as provider:
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"")
|
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
|
||||||
go globalSessions.GC()
|
go globalSessions.GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
* Use **file** as provider, the last param is the path where you want file to be stored:
|
* Use **file** as provider, the last param is the path where you want file to be stored:
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp")
|
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
|
||||||
go globalSessions.GC()
|
go globalSessions.GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
|
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie")
|
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
|
||||||
go globalSessions.GC()
|
go globalSessions.GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,15 +50,24 @@ Then in you web app init the global session manager
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
globalSessions, _ = session.NewManager(
|
globalSessions, _ = session.NewManager(
|
||||||
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value")
|
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
|
||||||
go globalSessions.GC()
|
go globalSessions.GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
* Use **Cookie** as provider:
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
globalSessions, _ = session.NewManager(
|
||||||
|
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
|
||||||
|
go globalSessions.GC()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Finally in the handlerfunc you can use it like this
|
Finally in the handlerfunc you can use it like this
|
||||||
|
|
||||||
func login(w http.ResponseWriter, r *http.Request) {
|
func login(w http.ResponseWriter, r *http.Request) {
|
||||||
sess := globalSessions.SessionStart(w, r)
|
sess := globalSessions.SessionStart(w, r)
|
||||||
defer sess.SessionRelease()
|
defer sess.SessionRelease(w)
|
||||||
username := sess.Get("username")
|
username := sess.Get("username")
|
||||||
fmt.Println(username)
|
fmt.Println(username)
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
|
|||||||
|
|
||||||
Writing a provider is easy. You only need to define two struct types
|
Writing a provider is easy. You only need to define two struct types
|
||||||
(Session and Provider), which satisfy the interface definition.
|
(Session and Provider), which satisfy the interface definition.
|
||||||
Maybe you will find the **memory** provider as good example.
|
Maybe you will find the **memory** provider is a good example.
|
||||||
|
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
Set(key, value interface{}) error //set session value
|
Set(key, value interface{}) error //set session value
|
||||||
Get(key interface{}) interface{} //get session value
|
Get(key interface{}) interface{} //get session value
|
||||||
Delete(key interface{}) error //delete session value
|
Delete(key interface{}) error //delete session value
|
||||||
SessionID() string //back current sessionID
|
SessionID() string //back current sessionID
|
||||||
SessionRelease() // release the resource & save data to provider
|
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
|
||||||
Flush() error //delete all data
|
Flush() error //delete all data
|
||||||
}
|
}
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
SessionInit(maxlifetime int64, savePath string) error
|
SessionInit(gclifetime int64, config string) error
|
||||||
SessionRead(sid string) (SessionStore, error)
|
SessionRead(sid string) (SessionStore, error)
|
||||||
SessionExist(sid string) bool
|
SessionExist(sid string) bool
|
||||||
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
||||||
|
170
session/sess_cookie.go
Normal file
170
session/sess_cookie.go
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cookiepder = &CookieProvider{}
|
||||||
|
|
||||||
|
// Cookie SessionStore
|
||||||
|
type CookieSessionStore struct {
|
||||||
|
sid string
|
||||||
|
values map[interface{}]interface{} // session data
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set value to cookie session.
|
||||||
|
// the value are encoded as gob with hash block string.
|
||||||
|
func (st *CookieSessionStore) Set(key, value interface{}) error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
st.values[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get value from cookie session
|
||||||
|
func (st *CookieSessionStore) Get(key interface{}) interface{} {
|
||||||
|
st.lock.RLock()
|
||||||
|
defer st.lock.RUnlock()
|
||||||
|
if v, ok := st.values[key]; ok {
|
||||||
|
return v
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete value in cookie session
|
||||||
|
func (st *CookieSessionStore) Delete(key interface{}) error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
delete(st.values, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean all values in cookie session
|
||||||
|
func (st *CookieSessionStore) Flush() error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
st.values = make(map[interface{}]interface{})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return id of this cookie session
|
||||||
|
func (st *CookieSessionStore) SessionID() string {
|
||||||
|
return st.sid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write cookie session to http response cookie
|
||||||
|
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
|
str, err := encodeCookie(cookiepder.block,
|
||||||
|
cookiepder.config.SecurityKey,
|
||||||
|
cookiepder.config.SecurityName,
|
||||||
|
st.values)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
|
||||||
|
Value: url.QueryEscape(str),
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: cookiepder.config.Secure}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type cookieConfig struct {
|
||||||
|
SecurityKey string `json:"securityKey"`
|
||||||
|
BlockKey string `json:"blockKey"`
|
||||||
|
SecurityName string `json:"securityName"`
|
||||||
|
CookieName string `json:"cookieName"`
|
||||||
|
Secure bool `json:"secure"`
|
||||||
|
Maxage int `json:"maxage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cookie session provider
|
||||||
|
type CookieProvider struct {
|
||||||
|
maxlifetime int64
|
||||||
|
config *cookieConfig
|
||||||
|
block cipher.Block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init cookie session provider with max lifetime and config json.
|
||||||
|
// maxlifetime is ignored.
|
||||||
|
// json config:
|
||||||
|
// securityKey - hash string
|
||||||
|
// blockKey - gob encode hash string. it's saved as aes crypto.
|
||||||
|
// securityName - recognized name in encoded cookie string
|
||||||
|
// cookieName - cookie name
|
||||||
|
// maxage - cookie max life time.
|
||||||
|
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
|
||||||
|
pder.config = &cookieConfig{}
|
||||||
|
err := json.Unmarshal([]byte(config), pder.config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if pder.config.BlockKey == "" {
|
||||||
|
pder.config.BlockKey = string(generateRandomKey(16))
|
||||||
|
}
|
||||||
|
if pder.config.SecurityName == "" {
|
||||||
|
pder.config.SecurityName = string(generateRandomKey(20))
|
||||||
|
}
|
||||||
|
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get SessionStore in cooke.
|
||||||
|
// decode cooke string to map and put into SessionStore with sid.
|
||||||
|
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
|
maps, _ := decodeCookie(pder.block,
|
||||||
|
pder.config.SecurityKey,
|
||||||
|
pder.config.SecurityName,
|
||||||
|
sid, pder.maxlifetime)
|
||||||
|
if maps == nil {
|
||||||
|
maps = make(map[interface{}]interface{})
|
||||||
|
}
|
||||||
|
rs := &CookieSessionStore{sid: sid, values: maps}
|
||||||
|
return rs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cookie session is always existed
|
||||||
|
func (pder *CookieProvider) SessionExist(sid string) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement method, no used.
|
||||||
|
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement method, no used.
|
||||||
|
func (pder *CookieProvider) SessionDestroy(sid string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement method, no used.
|
||||||
|
func (pder *CookieProvider) SessionGC() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement method, return 0.
|
||||||
|
func (pder *CookieProvider) SessionAll() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement method, no used.
|
||||||
|
func (pder *CookieProvider) SessionUpdate(sid string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register("cookie", cookiepder)
|
||||||
|
}
|
38
session/sess_cookie_test.go
Normal file
38
session/sess_cookie_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCookie(t *testing.T) {
|
||||||
|
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
|
||||||
|
globalSessions, err := NewManager("cookie", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("init cookie session err", err)
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
sess := globalSessions.SessionStart(w, r)
|
||||||
|
err = sess.Set("username", "astaxie")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("set error,", err)
|
||||||
|
}
|
||||||
|
if username := sess.Get("username"); username != "astaxie" {
|
||||||
|
t.Fatal("get username error")
|
||||||
|
}
|
||||||
|
sess.SessionRelease(w)
|
||||||
|
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
|
||||||
|
t.Fatal("setcookie error")
|
||||||
|
} else {
|
||||||
|
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
|
||||||
|
for k, v := range parts {
|
||||||
|
nameval := strings.Split(v, "=")
|
||||||
|
if k == 0 && nameval[0] != "gosessionid" {
|
||||||
|
t.Fatal("error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -17,6 +18,7 @@ var (
|
|||||||
gcmaxlifetime int64
|
gcmaxlifetime int64
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// File session store
|
||||||
type FileSessionStore struct {
|
type FileSessionStore struct {
|
||||||
f *os.File
|
f *os.File
|
||||||
sid string
|
sid string
|
||||||
@ -24,6 +26,7 @@ type FileSessionStore struct {
|
|||||||
values map[interface{}]interface{}
|
values map[interface{}]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set value to file session
|
||||||
func (fs *FileSessionStore) Set(key, value interface{}) error {
|
func (fs *FileSessionStore) Set(key, value interface{}) error {
|
||||||
fs.lock.Lock()
|
fs.lock.Lock()
|
||||||
defer fs.lock.Unlock()
|
defer fs.lock.Unlock()
|
||||||
@ -31,6 +34,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get value from file session
|
||||||
func (fs *FileSessionStore) Get(key interface{}) interface{} {
|
func (fs *FileSessionStore) Get(key interface{}) interface{} {
|
||||||
fs.lock.RLock()
|
fs.lock.RLock()
|
||||||
defer fs.lock.RUnlock()
|
defer fs.lock.RUnlock()
|
||||||
@ -42,6 +46,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete value in file session by given key
|
||||||
func (fs *FileSessionStore) Delete(key interface{}) error {
|
func (fs *FileSessionStore) Delete(key interface{}) error {
|
||||||
fs.lock.Lock()
|
fs.lock.Lock()
|
||||||
defer fs.lock.Unlock()
|
defer fs.lock.Unlock()
|
||||||
@ -49,6 +54,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean all values in file session
|
||||||
func (fs *FileSessionStore) Flush() error {
|
func (fs *FileSessionStore) Flush() error {
|
||||||
fs.lock.Lock()
|
fs.lock.Lock()
|
||||||
defer fs.lock.Unlock()
|
defer fs.lock.Unlock()
|
||||||
@ -56,11 +62,13 @@ func (fs *FileSessionStore) Flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get file session store id
|
||||||
func (fs *FileSessionStore) SessionID() string {
|
func (fs *FileSessionStore) SessionID() string {
|
||||||
return fs.sid
|
return fs.sid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *FileSessionStore) SessionRelease() {
|
// Write file session to local file with Gob string
|
||||||
|
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
defer fs.f.Close()
|
defer fs.f.Close()
|
||||||
b, err := encodeGob(fs.values)
|
b, err := encodeGob(fs.values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -71,17 +79,23 @@ func (fs *FileSessionStore) SessionRelease() {
|
|||||||
fs.f.Write(b)
|
fs.f.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// File session provider
|
||||||
type FileProvider struct {
|
type FileProvider struct {
|
||||||
maxlifetime int64
|
maxlifetime int64
|
||||||
savePath string
|
savePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init file session provider.
|
||||||
|
// savePath sets the session files path.
|
||||||
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
|
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
fp.maxlifetime = maxlifetime
|
fp.maxlifetime = maxlifetime
|
||||||
fp.savePath = savePath
|
fp.savePath = savePath
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read file session by sid.
|
||||||
|
// if file is not exist, create it.
|
||||||
|
// the file path is generated from sid string.
|
||||||
func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
|
func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
|
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -116,6 +130,8 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
|
|||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check file session exist.
|
||||||
|
// it checkes the file named from sid exist or not.
|
||||||
func (fp *FileProvider) SessionExist(sid string) bool {
|
func (fp *FileProvider) SessionExist(sid string) bool {
|
||||||
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
|
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -125,16 +141,20 @@ func (fp *FileProvider) SessionExist(sid string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove all files in this save path
|
||||||
func (fp *FileProvider) SessionDestroy(sid string) error {
|
func (fp *FileProvider) SessionDestroy(sid string) error {
|
||||||
os.Remove(path.Join(fp.savePath))
|
os.Remove(path.Join(fp.savePath))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Recycle files in save path
|
||||||
func (fp *FileProvider) SessionGC() {
|
func (fp *FileProvider) SessionGC() {
|
||||||
gcmaxlifetime = fp.maxlifetime
|
gcmaxlifetime = fp.maxlifetime
|
||||||
filepath.Walk(fp.savePath, gcpath)
|
filepath.Walk(fp.savePath, gcpath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get active file session number.
|
||||||
|
// it walks save path to count files.
|
||||||
func (fp *FileProvider) SessionAll() int {
|
func (fp *FileProvider) SessionAll() int {
|
||||||
a := &activeSession{}
|
a := &activeSession{}
|
||||||
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
|
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
|
||||||
@ -147,6 +167,8 @@ func (fp *FileProvider) SessionAll() int {
|
|||||||
return a.total
|
return a.total
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate new sid for file session.
|
||||||
|
// it delete old file and create new file named from new sid.
|
||||||
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
|
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -196,6 +218,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, err
|
|||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove file in save path if expired
|
||||||
func gcpath(path string, info os.FileInfo, err error) error {
|
func gcpath(path string, info os.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/gob"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
gob.Register([]interface{}{})
|
|
||||||
gob.Register(map[int]interface{}{})
|
|
||||||
gob.Register(map[string]interface{}{})
|
|
||||||
gob.Register(map[interface{}]interface{}{})
|
|
||||||
gob.Register(map[string]string{})
|
|
||||||
gob.Register(map[int]string{})
|
|
||||||
gob.Register(map[int]int{})
|
|
||||||
gob.Register(map[int]int64{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
|
|
||||||
buf := bytes.NewBuffer(nil)
|
|
||||||
enc := gob.NewEncoder(buf)
|
|
||||||
err := enc.Encode(obj)
|
|
||||||
if err != nil {
|
|
||||||
return []byte(""), err
|
|
||||||
}
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
|
|
||||||
buf := bytes.NewBuffer(encoded)
|
|
||||||
dec := gob.NewDecoder(buf)
|
|
||||||
var out map[interface{}]interface{}
|
|
||||||
err := dec.Decode(&out)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
@ -2,19 +2,23 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
|
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
|
||||||
|
|
||||||
|
// memory session store.
|
||||||
|
// it saved sessions in a map in memory.
|
||||||
type MemSessionStore struct {
|
type MemSessionStore struct {
|
||||||
sid string //session id唯一标示
|
sid string //session id
|
||||||
timeAccessed time.Time //最后访问时间
|
timeAccessed time.Time //last access time
|
||||||
value map[interface{}]interface{} //session里面存储的值
|
value map[interface{}]interface{} //session store
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set value to memory session
|
||||||
func (st *MemSessionStore) Set(key, value interface{}) error {
|
func (st *MemSessionStore) Set(key, value interface{}) error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -22,6 +26,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get value from memory session by key
|
||||||
func (st *MemSessionStore) Get(key interface{}) interface{} {
|
func (st *MemSessionStore) Get(key interface{}) interface{} {
|
||||||
st.lock.RLock()
|
st.lock.RLock()
|
||||||
defer st.lock.RUnlock()
|
defer st.lock.RUnlock()
|
||||||
@ -33,6 +38,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete in memory session by key
|
||||||
func (st *MemSessionStore) Delete(key interface{}) error {
|
func (st *MemSessionStore) Delete(key interface{}) error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -40,6 +46,7 @@ func (st *MemSessionStore) Delete(key interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clear all values in memory session
|
||||||
func (st *MemSessionStore) Flush() error {
|
func (st *MemSessionStore) Flush() error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -47,28 +54,31 @@ func (st *MemSessionStore) Flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get this id of memory session store
|
||||||
func (st *MemSessionStore) SessionID() string {
|
func (st *MemSessionStore) SessionID() string {
|
||||||
return st.sid
|
return st.sid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *MemSessionStore) SessionRelease() {
|
// Implement method, no used.
|
||||||
|
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type MemProvider struct {
|
type MemProvider struct {
|
||||||
lock sync.RWMutex //用来锁
|
lock sync.RWMutex // locker
|
||||||
sessions map[string]*list.Element //用来存储在内存
|
sessions map[string]*list.Element // map in memory
|
||||||
list *list.List //用来做gc
|
list *list.List // for gc
|
||||||
maxlifetime int64
|
maxlifetime int64
|
||||||
savePath string
|
savePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init memory session
|
||||||
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
|
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
pder.maxlifetime = maxlifetime
|
pder.maxlifetime = maxlifetime
|
||||||
pder.savePath = savePath
|
pder.savePath = savePath
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get memory session store by sid
|
||||||
func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
|
func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
pder.lock.RLock()
|
pder.lock.RLock()
|
||||||
if element, ok := pder.sessions[sid]; ok {
|
if element, ok := pder.sessions[sid]; ok {
|
||||||
@ -87,6 +97,7 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check session store exist in memory session by sid
|
||||||
func (pder *MemProvider) SessionExist(sid string) bool {
|
func (pder *MemProvider) SessionExist(sid string) bool {
|
||||||
pder.lock.RLock()
|
pder.lock.RLock()
|
||||||
defer pder.lock.RUnlock()
|
defer pder.lock.RUnlock()
|
||||||
@ -97,6 +108,7 @@ func (pder *MemProvider) SessionExist(sid string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate new sid for session store in memory session
|
||||||
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
pder.lock.RLock()
|
pder.lock.RLock()
|
||||||
if element, ok := pder.sessions[oldsid]; ok {
|
if element, ok := pder.sessions[oldsid]; ok {
|
||||||
@ -120,6 +132,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete session store in memory session by id
|
||||||
func (pder *MemProvider) SessionDestroy(sid string) error {
|
func (pder *MemProvider) SessionDestroy(sid string) error {
|
||||||
pder.lock.Lock()
|
pder.lock.Lock()
|
||||||
defer pder.lock.Unlock()
|
defer pder.lock.Unlock()
|
||||||
@ -131,6 +144,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean expired session stores in memory session
|
||||||
func (pder *MemProvider) SessionGC() {
|
func (pder *MemProvider) SessionGC() {
|
||||||
pder.lock.RLock()
|
pder.lock.RLock()
|
||||||
for {
|
for {
|
||||||
@ -152,10 +166,12 @@ func (pder *MemProvider) SessionGC() {
|
|||||||
pder.lock.RUnlock()
|
pder.lock.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get count number of memory session
|
||||||
func (pder *MemProvider) SessionAll() int {
|
func (pder *MemProvider) SessionAll() int {
|
||||||
return pder.list.Len()
|
return pder.list.Len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expand time of session store by id in memory session
|
||||||
func (pder *MemProvider) SessionUpdate(sid string) error {
|
func (pder *MemProvider) SessionUpdate(sid string) error {
|
||||||
pder.lock.Lock()
|
pder.lock.Lock()
|
||||||
defer pder.lock.Unlock()
|
defer pder.lock.Unlock()
|
||||||
|
35
session/sess_mem_test.go
Normal file
35
session/sess_mem_test.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMem(t *testing.T) {
|
||||||
|
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
|
||||||
|
go globalSessions.GC()
|
||||||
|
r, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
sess := globalSessions.SessionStart(w, r)
|
||||||
|
defer sess.SessionRelease(w)
|
||||||
|
err := sess.Set("username", "astaxie")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("set error,", err)
|
||||||
|
}
|
||||||
|
if username := sess.Get("username"); username != "astaxie" {
|
||||||
|
t.Fatal("get username error")
|
||||||
|
}
|
||||||
|
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
|
||||||
|
t.Fatal("setcookie error")
|
||||||
|
} else {
|
||||||
|
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
|
||||||
|
for k, v := range parts {
|
||||||
|
nameval := strings.Split(v, "=")
|
||||||
|
if k == 0 && nameval[0] != "gosessionid" {
|
||||||
|
t.Fatal("error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,14 +1,16 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
|
// mysql session support need create table as sql:
|
||||||
// CREATE TABLE `session` (
|
// CREATE TABLE `session` (
|
||||||
// `session_key` char(64) NOT NULL,
|
// `session_key` char(64) NOT NULL,
|
||||||
// `session_data` blob,
|
// session_data` blob,
|
||||||
// `session_expiry` int(11) unsigned NOT NULL,
|
// `session_expiry` int(11) unsigned NOT NULL,
|
||||||
// PRIMARY KEY (`session_key`)
|
// PRIMARY KEY (`session_key`)
|
||||||
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
|
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -17,6 +19,7 @@ import (
|
|||||||
|
|
||||||
var mysqlpder = &MysqlProvider{}
|
var mysqlpder = &MysqlProvider{}
|
||||||
|
|
||||||
|
// mysql session store
|
||||||
type MysqlSessionStore struct {
|
type MysqlSessionStore struct {
|
||||||
c *sql.DB
|
c *sql.DB
|
||||||
sid string
|
sid string
|
||||||
@ -24,6 +27,8 @@ type MysqlSessionStore struct {
|
|||||||
values map[interface{}]interface{}
|
values map[interface{}]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set value in mysql session.
|
||||||
|
// it is temp value in map.
|
||||||
func (st *MysqlSessionStore) Set(key, value interface{}) error {
|
func (st *MysqlSessionStore) Set(key, value interface{}) error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -31,6 +36,7 @@ func (st *MysqlSessionStore) Set(key, value interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get value from mysql session
|
||||||
func (st *MysqlSessionStore) Get(key interface{}) interface{} {
|
func (st *MysqlSessionStore) Get(key interface{}) interface{} {
|
||||||
st.lock.RLock()
|
st.lock.RLock()
|
||||||
defer st.lock.RUnlock()
|
defer st.lock.RUnlock()
|
||||||
@ -42,6 +48,7 @@ func (st *MysqlSessionStore) Get(key interface{}) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete value in mysql session
|
||||||
func (st *MysqlSessionStore) Delete(key interface{}) error {
|
func (st *MysqlSessionStore) Delete(key interface{}) error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -49,6 +56,7 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clear all values in mysql session
|
||||||
func (st *MysqlSessionStore) Flush() error {
|
func (st *MysqlSessionStore) Flush() error {
|
||||||
st.lock.Lock()
|
st.lock.Lock()
|
||||||
defer st.lock.Unlock()
|
defer st.lock.Unlock()
|
||||||
@ -56,26 +64,31 @@ func (st *MysqlSessionStore) Flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get session id of this mysql session store
|
||||||
func (st *MysqlSessionStore) SessionID() string {
|
func (st *MysqlSessionStore) SessionID() string {
|
||||||
return st.sid
|
return st.sid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *MysqlSessionStore) SessionRelease() {
|
// save mysql session values to database.
|
||||||
|
// must call this method to save values to database.
|
||||||
|
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
defer st.c.Close()
|
defer st.c.Close()
|
||||||
if len(st.values) > 0 {
|
|
||||||
b, err := encodeGob(st.values)
|
b, err := encodeGob(st.values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
|
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
|
||||||
}
|
b, time.Now().Unix(), st.sid)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mysql session provider
|
||||||
type MysqlProvider struct {
|
type MysqlProvider struct {
|
||||||
maxlifetime int64
|
maxlifetime int64
|
||||||
savePath string
|
savePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connect to mysql
|
||||||
func (mp *MysqlProvider) connectInit() *sql.DB {
|
func (mp *MysqlProvider) connectInit() *sql.DB {
|
||||||
db, e := sql.Open("mysql", mp.savePath)
|
db, e := sql.Open("mysql", mp.savePath)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
@ -84,19 +97,23 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init mysql session.
|
||||||
|
// savepath is the connection string of mysql.
|
||||||
func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
|
func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
mp.maxlifetime = maxlifetime
|
mp.maxlifetime = maxlifetime
|
||||||
mp.savePath = savePath
|
mp.savePath = savePath
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get mysql session by sid
|
||||||
func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
||||||
var sessiondata []byte
|
var sessiondata []byte
|
||||||
err := row.Scan(&sessiondata)
|
err := row.Scan(&sessiondata)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix())
|
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
|
||||||
|
sid, "", time.Now().Unix())
|
||||||
}
|
}
|
||||||
var kv map[interface{}]interface{}
|
var kv map[interface{}]interface{}
|
||||||
if len(sessiondata) == 0 {
|
if len(sessiondata) == 0 {
|
||||||
@ -111,8 +128,10 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
|||||||
return rs, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check mysql session exist
|
||||||
func (mp *MysqlProvider) SessionExist(sid string) bool {
|
func (mp *MysqlProvider) SessionExist(sid string) bool {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
|
defer c.Close()
|
||||||
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
||||||
var sessiondata []byte
|
var sessiondata []byte
|
||||||
err := row.Scan(&sessiondata)
|
err := row.Scan(&sessiondata)
|
||||||
@ -123,6 +142,7 @@ func (mp *MysqlProvider) SessionExist(sid string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate new sid for mysql session
|
||||||
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
|
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
|
||||||
@ -145,6 +165,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
|
|||||||
return rs, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete mysql session by sid
|
||||||
func (mp *MysqlProvider) SessionDestroy(sid string) error {
|
func (mp *MysqlProvider) SessionDestroy(sid string) error {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
c.Exec("DELETE FROM session where session_key=?", sid)
|
c.Exec("DELETE FROM session where session_key=?", sid)
|
||||||
@ -152,6 +173,7 @@ func (mp *MysqlProvider) SessionDestroy(sid string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete expired values in mysql session
|
||||||
func (mp *MysqlProvider) SessionGC() {
|
func (mp *MysqlProvider) SessionGC() {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
|
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
|
||||||
@ -159,6 +181,7 @@ func (mp *MysqlProvider) SessionGC() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// count values in mysql session
|
||||||
func (mp *MysqlProvider) SessionAll() int {
|
func (mp *MysqlProvider) SessionAll() int {
|
||||||
c := mp.connectInit()
|
c := mp.connectInit()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -10,18 +11,21 @@ import (
|
|||||||
|
|
||||||
var redispder = &RedisProvider{}
|
var redispder = &RedisProvider{}
|
||||||
|
|
||||||
|
// redis max pool size
|
||||||
var MAX_POOL_SIZE = 100
|
var MAX_POOL_SIZE = 100
|
||||||
|
|
||||||
var redisPool chan redis.Conn
|
var redisPool chan redis.Conn
|
||||||
|
|
||||||
|
// redis session store
|
||||||
type RedisSessionStore struct {
|
type RedisSessionStore struct {
|
||||||
c redis.Conn
|
p *redis.Pool
|
||||||
sid string
|
sid string
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
values map[interface{}]interface{}
|
values map[interface{}]interface{}
|
||||||
maxlifetime int64
|
maxlifetime int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set value in redis session
|
||||||
func (rs *RedisSessionStore) Set(key, value interface{}) error {
|
func (rs *RedisSessionStore) Set(key, value interface{}) error {
|
||||||
rs.lock.Lock()
|
rs.lock.Lock()
|
||||||
defer rs.lock.Unlock()
|
defer rs.lock.Unlock()
|
||||||
@ -29,6 +33,7 @@ func (rs *RedisSessionStore) Set(key, value interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get value in redis session
|
||||||
func (rs *RedisSessionStore) Get(key interface{}) interface{} {
|
func (rs *RedisSessionStore) Get(key interface{}) interface{} {
|
||||||
rs.lock.RLock()
|
rs.lock.RLock()
|
||||||
defer rs.lock.RUnlock()
|
defer rs.lock.RUnlock()
|
||||||
@ -40,6 +45,7 @@ func (rs *RedisSessionStore) Get(key interface{}) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete value in redis session
|
||||||
func (rs *RedisSessionStore) Delete(key interface{}) error {
|
func (rs *RedisSessionStore) Delete(key interface{}) error {
|
||||||
rs.lock.Lock()
|
rs.lock.Lock()
|
||||||
defer rs.lock.Unlock()
|
defer rs.lock.Unlock()
|
||||||
@ -47,6 +53,7 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clear all values in redis session
|
||||||
func (rs *RedisSessionStore) Flush() error {
|
func (rs *RedisSessionStore) Flush() error {
|
||||||
rs.lock.Lock()
|
rs.lock.Lock()
|
||||||
defer rs.lock.Unlock()
|
defer rs.lock.Unlock()
|
||||||
@ -54,22 +61,31 @@ func (rs *RedisSessionStore) Flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get redis session id
|
||||||
func (rs *RedisSessionStore) SessionID() string {
|
func (rs *RedisSessionStore) SessionID() string {
|
||||||
return rs.sid
|
return rs.sid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RedisSessionStore) SessionRelease() {
|
// save session values to redis
|
||||||
defer rs.c.Close()
|
func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
if len(rs.values) > 0 {
|
c := rs.p.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
// if rs.values is empty, return directly
|
||||||
|
if len(rs.values) < 1 {
|
||||||
|
c.Do("DEL", rs.sid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
b, err := encodeGob(rs.values)
|
b, err := encodeGob(rs.values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rs.c.Do("SET", rs.sid, string(b))
|
|
||||||
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
|
c.Do("SET", rs.sid, string(b), "EX", rs.maxlifetime)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// redis session provider
|
||||||
type RedisProvider struct {
|
type RedisProvider struct {
|
||||||
maxlifetime int64
|
maxlifetime int64
|
||||||
savePath string
|
savePath string
|
||||||
@ -78,8 +94,9 @@ type RedisProvider struct {
|
|||||||
poollist *redis.Pool
|
poollist *redis.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init redis session
|
||||||
// savepath like redis server addr,pool size,password
|
// savepath like redis server addr,pool size,password
|
||||||
//127.0.0.1:6379,100,astaxie
|
// e.g. 127.0.0.1:6379,100,astaxie
|
||||||
func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
rp.maxlifetime = maxlifetime
|
rp.maxlifetime = maxlifetime
|
||||||
configs := strings.Split(savePath, ",")
|
configs := strings.Split(savePath, ",")
|
||||||
@ -115,12 +132,11 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read redis session by sid
|
||||||
func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
|
func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
c := rp.poollist.Get()
|
c := rp.poollist.Get()
|
||||||
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
|
defer c.Close()
|
||||||
c.Do("SET", sid)
|
|
||||||
}
|
|
||||||
c.Do("EXPIRE", sid, rp.maxlifetime)
|
|
||||||
kvs, err := redis.String(c.Do("GET", sid))
|
kvs, err := redis.String(c.Do("GET", sid))
|
||||||
var kv map[interface{}]interface{}
|
var kv map[interface{}]interface{}
|
||||||
if len(kvs) == 0 {
|
if len(kvs) == 0 {
|
||||||
@ -131,13 +147,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
|
||||||
|
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||||
return rs, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check redis session exist by sid
|
||||||
func (rp *RedisProvider) SessionExist(sid string) bool {
|
func (rp *RedisProvider) SessionExist(sid string) bool {
|
||||||
c := rp.poollist.Get()
|
c := rp.poollist.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
|
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
|
||||||
return false
|
return false
|
||||||
} else {
|
} else {
|
||||||
@ -145,13 +164,21 @@ func (rp *RedisProvider) SessionExist(sid string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate new sid for redis session
|
||||||
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
c := rp.poollist.Get()
|
c := rp.poollist.Get()
|
||||||
if existed, err := redis.Int(c.Do("EXISTS", oldsid)); err != nil || existed == 0 {
|
defer c.Close()
|
||||||
c.Do("SET", oldsid)
|
|
||||||
}
|
if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 {
|
||||||
|
// oldsid doesn't exists, set the new sid directly
|
||||||
|
// ignore error here, since if it return error
|
||||||
|
// the existed value will be 0
|
||||||
|
c.Do("SET", sid, "", "EX", rp.maxlifetime)
|
||||||
|
} else {
|
||||||
c.Do("RENAME", oldsid, sid)
|
c.Do("RENAME", oldsid, sid)
|
||||||
c.Do("EXPIRE", sid, rp.maxlifetime)
|
c.Do("EXPIRE", sid, rp.maxlifetime)
|
||||||
|
}
|
||||||
|
|
||||||
kvs, err := redis.String(c.Do("GET", sid))
|
kvs, err := redis.String(c.Do("GET", sid))
|
||||||
var kv map[interface{}]interface{}
|
var kv map[interface{}]interface{}
|
||||||
if len(kvs) == 0 {
|
if len(kvs) == 0 {
|
||||||
@ -162,24 +189,27 @@ func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
|
||||||
|
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||||
return rs, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// delete redis session by id
|
||||||
func (rp *RedisProvider) SessionDestroy(sid string) error {
|
func (rp *RedisProvider) SessionDestroy(sid string) error {
|
||||||
c := rp.poollist.Get()
|
c := rp.poollist.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
c.Do("DEL", sid)
|
c.Do("DEL", sid)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Impelment method, no used.
|
||||||
func (rp *RedisProvider) SessionGC() {
|
func (rp *RedisProvider) SessionGC() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// @todo
|
// @todo
|
||||||
func (rp *RedisProvider) SessionAll() int {
|
func (rp *RedisProvider) SessionAll() int {
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
|
|||||||
t.Error("decode int error")
|
t.Error("decode int error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
str := generateRandomKey(20)
|
||||||
|
if len(str) != 20 {
|
||||||
|
t.Fatal("generate length is not equal to 20")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCookieEncodeDecode(t *testing.T) {
|
||||||
|
hashKey := "testhashKey"
|
||||||
|
blockkey := generateRandomKey(16)
|
||||||
|
block, err := aes.NewCipher(blockkey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("NewCipher:", err)
|
||||||
|
}
|
||||||
|
securityName := string(generateRandomKey(20))
|
||||||
|
val := make(map[interface{}]interface{})
|
||||||
|
val["name"] = "astaxie"
|
||||||
|
val["gender"] = "male"
|
||||||
|
str, err := encodeCookie(block, hashKey, securityName, val)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("encodeCookie:", err)
|
||||||
|
}
|
||||||
|
dst := make(map[interface{}]interface{})
|
||||||
|
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("decodeCookie", err)
|
||||||
|
}
|
||||||
|
if dst["name"] != "astaxie" {
|
||||||
|
t.Fatal("dst get map error")
|
||||||
|
}
|
||||||
|
if dst["gender"] != "male" {
|
||||||
|
t.Fatal("dst get map error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfig(t *testing.T) {
|
||||||
|
s := `{"cookieName":"gosessionid","gclifetime":3600}`
|
||||||
|
cf := new(managerConfig)
|
||||||
|
cf.EnableSetCookie = true
|
||||||
|
err := json.Unmarshal([]byte(s), cf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("parse json error,", err)
|
||||||
|
}
|
||||||
|
if cf.CookieName != "gosessionid" {
|
||||||
|
t.Fatal("parseconfig get cookiename error")
|
||||||
|
}
|
||||||
|
if cf.Gclifetime != 3600 {
|
||||||
|
t.Fatal("parseconfig get gclifetime error")
|
||||||
|
}
|
||||||
|
|
||||||
|
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
|
||||||
|
cf2 := new(managerConfig)
|
||||||
|
cf2.EnableSetCookie = true
|
||||||
|
err = json.Unmarshal([]byte(cc), cf2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("parse json error,", err)
|
||||||
|
}
|
||||||
|
if cf2.CookieName != "gosessionid" {
|
||||||
|
t.Fatal("parseconfig get cookiename error")
|
||||||
|
}
|
||||||
|
if cf2.Gclifetime != 3600 {
|
||||||
|
t.Fatal("parseconfig get gclifetime error")
|
||||||
|
}
|
||||||
|
if cf2.EnableSetCookie != false {
|
||||||
|
t.Fatal("parseconfig get enableSetCookie error")
|
||||||
|
}
|
||||||
|
cconfig := new(cookieConfig)
|
||||||
|
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("parse ProviderConfig err,", err)
|
||||||
|
}
|
||||||
|
if cconfig.CookieName != "gosessionid" {
|
||||||
|
t.Fatal("ProviderConfig get cookieName error")
|
||||||
|
}
|
||||||
|
if cconfig.SecurityKey != "beegocookiehashkey" {
|
||||||
|
t.Fatal("ProviderConfig get securityKey error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
188
session/sess_utils.go
Normal file
188
session/sess_utils.go
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/gob"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gob.Register([]interface{}{})
|
||||||
|
gob.Register(map[int]interface{}{})
|
||||||
|
gob.Register(map[string]interface{}{})
|
||||||
|
gob.Register(map[interface{}]interface{}{})
|
||||||
|
gob.Register(map[string]string{})
|
||||||
|
gob.Register(map[int]string{})
|
||||||
|
gob.Register(map[int]int{})
|
||||||
|
gob.Register(map[int]int64{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
|
||||||
|
buf := bytes.NewBuffer(nil)
|
||||||
|
enc := gob.NewEncoder(buf)
|
||||||
|
err := enc.Encode(obj)
|
||||||
|
if err != nil {
|
||||||
|
return []byte(""), err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
|
||||||
|
buf := bytes.NewBuffer(encoded)
|
||||||
|
dec := gob.NewDecoder(buf)
|
||||||
|
var out map[interface{}]interface{}
|
||||||
|
err := dec.Decode(&out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomKey creates a random key with the given strength.
|
||||||
|
func generateRandomKey(strength int) []byte {
|
||||||
|
k := make([]byte, strength)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encryption -----------------------------------------------------------------
|
||||||
|
|
||||||
|
// encrypt encrypts a value using the given block in counter mode.
|
||||||
|
//
|
||||||
|
// A random initialization vector (http://goo.gl/zF67k) with the length of the
|
||||||
|
// block size is prepended to the resulting ciphertext.
|
||||||
|
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||||
|
iv := generateRandomKey(block.BlockSize())
|
||||||
|
if iv == nil {
|
||||||
|
return nil, errors.New("encrypt: failed to generate random iv")
|
||||||
|
}
|
||||||
|
// Encrypt it.
|
||||||
|
stream := cipher.NewCTR(block, iv)
|
||||||
|
stream.XORKeyStream(value, value)
|
||||||
|
// Return iv + ciphertext.
|
||||||
|
return append(iv, value...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt decrypts a value using the given block in counter mode.
|
||||||
|
//
|
||||||
|
// The value to be decrypted must be prepended by a initialization vector
|
||||||
|
// (http://goo.gl/zF67k) with the length of the block size.
|
||||||
|
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||||
|
size := block.BlockSize()
|
||||||
|
if len(value) > size {
|
||||||
|
// Extract iv.
|
||||||
|
iv := value[:size]
|
||||||
|
// Extract ciphertext.
|
||||||
|
value = value[size:]
|
||||||
|
// Decrypt it.
|
||||||
|
stream := cipher.NewCTR(block, iv)
|
||||||
|
stream.XORKeyStream(value, value)
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("decrypt: the value could not be decrypted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
|
||||||
|
var err error
|
||||||
|
var b []byte
|
||||||
|
// 1. encodeGob.
|
||||||
|
if b, err = encodeGob(value); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// 2. Encrypt (optional).
|
||||||
|
if b, err = encrypt(block, b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
b = encode(b)
|
||||||
|
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
|
||||||
|
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
|
||||||
|
h := hmac.New(sha1.New, []byte(hashKey))
|
||||||
|
h.Write(b)
|
||||||
|
sig := h.Sum(nil)
|
||||||
|
// Append mac, remove name.
|
||||||
|
b = append(b, sig...)[len(name)+1:]
|
||||||
|
// 4. Encode to base64.
|
||||||
|
b = encode(b)
|
||||||
|
// Done.
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
|
||||||
|
// 1. Decode from base64.
|
||||||
|
b, err := decode([]byte(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 2. Verify MAC. Value is "date|value|mac".
|
||||||
|
parts := bytes.SplitN(b, []byte("|"), 3)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, errors.New("Decode: invalid value %v")
|
||||||
|
}
|
||||||
|
|
||||||
|
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
|
||||||
|
h := hmac.New(sha1.New, []byte(hashKey))
|
||||||
|
h.Write(b)
|
||||||
|
sig := h.Sum(nil)
|
||||||
|
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
|
||||||
|
return nil, errors.New("Decode: the value is not valid")
|
||||||
|
}
|
||||||
|
// 3. Verify date ranges.
|
||||||
|
var t1 int64
|
||||||
|
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
|
||||||
|
return nil, errors.New("Decode: invalid timestamp")
|
||||||
|
}
|
||||||
|
t2 := time.Now().UTC().Unix()
|
||||||
|
if t1 > t2 {
|
||||||
|
return nil, errors.New("Decode: timestamp is too new")
|
||||||
|
}
|
||||||
|
if t1 < t2-gcmaxlifetime {
|
||||||
|
return nil, errors.New("Decode: expired timestamp")
|
||||||
|
}
|
||||||
|
// 4. Decrypt (optional).
|
||||||
|
b, err = decode(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if b, err = decrypt(block, b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 5. decodeGob.
|
||||||
|
if dst, err := decodeGob(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
|
// Done.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding -------------------------------------------------------------------
|
||||||
|
|
||||||
|
// encode encodes a value using base64.
|
||||||
|
func encode(value []byte) []byte {
|
||||||
|
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
|
||||||
|
base64.URLEncoding.Encode(encoded, value)
|
||||||
|
return encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode decodes a cookie using base64.
|
||||||
|
func decode(value []byte) ([]byte, error) {
|
||||||
|
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
|
||||||
|
b, err := base64.URLEncoding.Decode(decoded, value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decoded[:b], nil
|
||||||
|
}
|
@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -13,17 +14,20 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SessionStore contains all data for one session process with specific id.
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
Set(key, value interface{}) error //set session value
|
Set(key, value interface{}) error //set session value
|
||||||
Get(key interface{}) interface{} //get session value
|
Get(key interface{}) interface{} //get session value
|
||||||
Delete(key interface{}) error //delete session value
|
Delete(key interface{}) error //delete session value
|
||||||
SessionID() string //back current sessionID
|
SessionID() string //back current sessionID
|
||||||
SessionRelease() // release the resource & save data to provider
|
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
|
||||||
Flush() error //delete all data
|
Flush() error //delete all data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Provider contains global session methods and saved SessionStores.
|
||||||
|
// it can operate a SessionStore by its id.
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
SessionInit(maxlifetime int64, savePath string) error
|
SessionInit(gclifetime int64, config string) error
|
||||||
SessionRead(sid string) (SessionStore, error)
|
SessionRead(sid string) (SessionStore, error)
|
||||||
SessionExist(sid string) bool
|
SessionExist(sid string) bool
|
||||||
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
||||||
@ -47,90 +51,86 @@ func Register(name string, provide Provider) {
|
|||||||
provides[name] = provide
|
provides[name] = provide
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type managerConfig struct {
|
||||||
cookieName string //private cookiename
|
CookieName string `json:"cookieName"`
|
||||||
provider Provider
|
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
|
||||||
maxlifetime int64
|
Gclifetime int64 `json:"gclifetime"`
|
||||||
hashfunc string //support md5 & sha1
|
Maxlifetime int64 `json:"maxLifetime"`
|
||||||
hashkey string
|
Maxage int `json:"maxage"`
|
||||||
maxage int //cookielifetime
|
Secure bool `json:"secure"`
|
||||||
secure bool
|
SessionIDHashFunc string `json:"sessionIDHashFunc"`
|
||||||
options []interface{}
|
SessionIDHashKey string `json:"sessionIDHashKey"`
|
||||||
|
CookieLifeTime int64 `json:"cookieLifeTime"`
|
||||||
|
ProviderConfig string `json:"providerConfig"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//options
|
// Manager contains Provider and its configuration.
|
||||||
|
type Manager struct {
|
||||||
|
provider Provider
|
||||||
|
config *managerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new Manager with provider name and json config string.
|
||||||
|
// provider name:
|
||||||
|
// 1. cookie
|
||||||
|
// 2. file
|
||||||
|
// 3. memory
|
||||||
|
// 4. redis
|
||||||
|
// 5. mysql
|
||||||
|
// json config:
|
||||||
// 1. is https default false
|
// 1. is https default false
|
||||||
// 2. hashfunc default sha1
|
// 2. hashfunc default sha1
|
||||||
// 3. hashkey default beegosessionkey
|
// 3. hashkey default beegosessionkey
|
||||||
// 4. maxage default is none
|
// 4. maxage default is none
|
||||||
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
|
func NewManager(provideName, config string) (*Manager, error) {
|
||||||
provider, ok := provides[provideName]
|
provider, ok := provides[provideName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
|
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
|
||||||
}
|
}
|
||||||
provider.SessionInit(maxlifetime, savePath)
|
cf := new(managerConfig)
|
||||||
secure := false
|
cf.EnableSetCookie = true
|
||||||
if len(options) > 0 {
|
err := json.Unmarshal([]byte(config), cf)
|
||||||
secure = options[0].(bool)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
hashfunc := "sha1"
|
if cf.Maxlifetime == 0 {
|
||||||
if len(options) > 1 {
|
cf.Maxlifetime = cf.Gclifetime
|
||||||
hashfunc = options[1].(string)
|
|
||||||
}
|
}
|
||||||
hashkey := "beegosessionkey"
|
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
|
||||||
if len(options) > 2 {
|
if err != nil {
|
||||||
hashkey = options[2].(string)
|
return nil, err
|
||||||
}
|
|
||||||
maxage := -1
|
|
||||||
if len(options) > 3 {
|
|
||||||
switch options[3].(type) {
|
|
||||||
case int:
|
|
||||||
if options[3].(int) > 0 {
|
|
||||||
maxage = options[3].(int)
|
|
||||||
} else if options[3].(int) < 0 {
|
|
||||||
maxage = 0
|
|
||||||
}
|
|
||||||
case int64:
|
|
||||||
if options[3].(int64) > 0 {
|
|
||||||
maxage = int(options[3].(int64))
|
|
||||||
} else if options[3].(int64) < 0 {
|
|
||||||
maxage = 0
|
|
||||||
}
|
|
||||||
case int32:
|
|
||||||
if options[3].(int32) > 0 {
|
|
||||||
maxage = int(options[3].(int32))
|
|
||||||
} else if options[3].(int32) < 0 {
|
|
||||||
maxage = 0
|
|
||||||
}
|
}
|
||||||
|
if cf.SessionIDHashFunc == "" {
|
||||||
|
cf.SessionIDHashFunc = "sha1"
|
||||||
}
|
}
|
||||||
|
if cf.SessionIDHashKey == "" {
|
||||||
|
cf.SessionIDHashKey = string(generateRandomKey(16))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Manager{
|
return &Manager{
|
||||||
provider: provider,
|
provider,
|
||||||
cookieName: cookieName,
|
cf,
|
||||||
maxlifetime: maxlifetime,
|
|
||||||
hashfunc: hashfunc,
|
|
||||||
hashkey: hashkey,
|
|
||||||
maxage: maxage,
|
|
||||||
secure: secure,
|
|
||||||
options: options,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//get Session
|
// Start session. generate or read the session id from http request.
|
||||||
|
// if session id exists, return SessionStore with this id.
|
||||||
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
||||||
cookie, err := r.Cookie(manager.cookieName)
|
cookie, err := r.Cookie(manager.config.CookieName)
|
||||||
if err != nil || cookie.Value == "" {
|
if err != nil || cookie.Value == "" {
|
||||||
sid := manager.sessionId(r)
|
sid := manager.sessionId(r)
|
||||||
session, _ = manager.provider.SessionRead(sid)
|
session, _ = manager.provider.SessionRead(sid)
|
||||||
cookie = &http.Cookie{Name: manager.cookieName,
|
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||||
Value: url.QueryEscape(sid),
|
Value: url.QueryEscape(sid),
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: manager.secure}
|
Secure: manager.config.Secure}
|
||||||
if manager.maxage >= 0 {
|
if manager.config.Maxage >= 0 {
|
||||||
cookie.MaxAge = manager.maxage
|
cookie.MaxAge = manager.config.Maxage
|
||||||
}
|
}
|
||||||
|
if manager.config.EnableSetCookie {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
|
}
|
||||||
r.AddCookie(cookie)
|
r.AddCookie(cookie)
|
||||||
} else {
|
} else {
|
||||||
sid, _ := url.QueryUnescape(cookie.Value)
|
sid, _ := url.QueryUnescape(cookie.Value)
|
||||||
@ -139,55 +139,65 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
|
|||||||
} else {
|
} else {
|
||||||
sid = manager.sessionId(r)
|
sid = manager.sessionId(r)
|
||||||
session, _ = manager.provider.SessionRead(sid)
|
session, _ = manager.provider.SessionRead(sid)
|
||||||
cookie = &http.Cookie{Name: manager.cookieName,
|
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||||
Value: url.QueryEscape(sid),
|
Value: url.QueryEscape(sid),
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: manager.secure}
|
Secure: manager.config.Secure}
|
||||||
if manager.maxage >= 0 {
|
if manager.config.Maxage >= 0 {
|
||||||
cookie.MaxAge = manager.maxage
|
cookie.MaxAge = manager.config.Maxage
|
||||||
}
|
}
|
||||||
|
if manager.config.EnableSetCookie {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
|
}
|
||||||
r.AddCookie(cookie)
|
r.AddCookie(cookie)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//Destroy sessionid
|
// Destroy session by its id in http request cookie.
|
||||||
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
|
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
|
||||||
cookie, err := r.Cookie(manager.cookieName)
|
cookie, err := r.Cookie(manager.config.CookieName)
|
||||||
if err != nil || cookie.Value == "" {
|
if err != nil || cookie.Value == "" {
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
manager.provider.SessionDestroy(cookie.Value)
|
manager.provider.SessionDestroy(cookie.Value)
|
||||||
expiration := time.Now()
|
expiration := time.Now()
|
||||||
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1}
|
cookie := http.Cookie{Name: manager.config.CookieName,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Expires: expiration,
|
||||||
|
MaxAge: -1}
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get SessionStore by its id.
|
||||||
func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err error) {
|
func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err error) {
|
||||||
sessions, err = manager.provider.SessionRead(sid)
|
sessions, err = manager.provider.SessionRead(sid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start session gc process.
|
||||||
|
// it can do gc in times after gc lifetime.
|
||||||
func (manager *Manager) GC() {
|
func (manager *Manager) GC() {
|
||||||
manager.provider.SessionGC()
|
manager.provider.SessionGC()
|
||||||
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
|
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regenerate a session id for this SessionStore who's id is saving in http request.
|
||||||
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
|
||||||
sid := manager.sessionId(r)
|
sid := manager.sessionId(r)
|
||||||
cookie, err := r.Cookie(manager.cookieName)
|
cookie, err := r.Cookie(manager.config.CookieName)
|
||||||
if err != nil && cookie.Value == "" {
|
if err != nil && cookie.Value == "" {
|
||||||
//delete old cookie
|
//delete old cookie
|
||||||
session, _ = manager.provider.SessionRead(sid)
|
session, _ = manager.provider.SessionRead(sid)
|
||||||
cookie = &http.Cookie{Name: manager.cookieName,
|
cookie = &http.Cookie{Name: manager.config.CookieName,
|
||||||
Value: url.QueryEscape(sid),
|
Value: url.QueryEscape(sid),
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: manager.secure,
|
Secure: manager.config.Secure,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
oldsid, _ := url.QueryUnescape(cookie.Value)
|
oldsid, _ := url.QueryUnescape(cookie.Value)
|
||||||
@ -196,44 +206,47 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
|
|||||||
cookie.HttpOnly = true
|
cookie.HttpOnly = true
|
||||||
cookie.Path = "/"
|
cookie.Path = "/"
|
||||||
}
|
}
|
||||||
if manager.maxage >= 0 {
|
if manager.config.Maxage >= 0 {
|
||||||
cookie.MaxAge = manager.maxage
|
cookie.MaxAge = manager.config.Maxage
|
||||||
}
|
}
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
r.AddCookie(cookie)
|
r.AddCookie(cookie)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get all active sessions count number.
|
||||||
func (manager *Manager) GetActiveSession() int {
|
func (manager *Manager) GetActiveSession() int {
|
||||||
return manager.provider.SessionAll()
|
return manager.provider.SessionAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set hash function for generating session id.
|
||||||
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
|
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
|
||||||
manager.hashfunc = hasfunc
|
manager.config.SessionIDHashFunc = hasfunc
|
||||||
manager.hashkey = hashkey
|
manager.config.SessionIDHashKey = hashkey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set cookie with https.
|
||||||
func (manager *Manager) SetSecure(secure bool) {
|
func (manager *Manager) SetSecure(secure bool) {
|
||||||
manager.secure = secure
|
manager.config.Secure = secure
|
||||||
}
|
}
|
||||||
|
|
||||||
//remote_addr cruunixnano randdata
|
// generate session id with rand string, unix nano time, remote addr by hash function.
|
||||||
func (manager *Manager) sessionId(r *http.Request) (sid string) {
|
func (manager *Manager) sessionId(r *http.Request) (sid string) {
|
||||||
bs := make([]byte, 24)
|
bs := make([]byte, 24)
|
||||||
if _, err := io.ReadFull(rand.Reader, bs); err != nil {
|
if _, err := io.ReadFull(rand.Reader, bs); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
|
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
|
||||||
if manager.hashfunc == "md5" {
|
if manager.config.SessionIDHashFunc == "md5" {
|
||||||
h := md5.New()
|
h := md5.New()
|
||||||
h.Write([]byte(sig))
|
h.Write([]byte(sig))
|
||||||
sid = hex.EncodeToString(h.Sum(nil))
|
sid = hex.EncodeToString(h.Sum(nil))
|
||||||
} else if manager.hashfunc == "sha1" {
|
} else if manager.config.SessionIDHashFunc == "sha1" {
|
||||||
h := hmac.New(sha1.New, []byte(manager.hashkey))
|
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
|
||||||
fmt.Fprintf(h, "%s", sig)
|
fmt.Fprintf(h, "%s", sig)
|
||||||
sid = hex.EncodeToString(h.Sum(nil))
|
sid = hex.EncodeToString(h.Sum(nil))
|
||||||
} else {
|
} else {
|
||||||
h := hmac.New(sha1.New, []byte(manager.hashkey))
|
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
|
||||||
fmt.Fprintf(h, "%s", sig)
|
fmt.Fprintf(h, "%s", sig)
|
||||||
sid = hex.EncodeToString(h.Sum(nil))
|
sid = hex.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
var port = ""
|
var port = ""
|
||||||
var baseUrl = "http://localhost:"
|
var baseUrl = "http://localhost:"
|
||||||
|
|
||||||
|
// beego test request client
|
||||||
type TestHttpRequest struct {
|
type TestHttpRequest struct {
|
||||||
httplib.BeegoHttpRequest
|
httplib.BeegoHttpRequest
|
||||||
}
|
}
|
||||||
@ -24,22 +25,27 @@ func getPort() string {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns test client in GET method
|
||||||
func Get(path string) *TestHttpRequest {
|
func Get(path string) *TestHttpRequest {
|
||||||
return &TestHttpRequest{*httplib.Get(baseUrl + getPort() + path)}
|
return &TestHttpRequest{*httplib.Get(baseUrl + getPort() + path)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns test client in POST method
|
||||||
func Post(path string) *TestHttpRequest {
|
func Post(path string) *TestHttpRequest {
|
||||||
return &TestHttpRequest{*httplib.Post(baseUrl + getPort() + path)}
|
return &TestHttpRequest{*httplib.Post(baseUrl + getPort() + path)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns test client in PUT method
|
||||||
func Put(path string) *TestHttpRequest {
|
func Put(path string) *TestHttpRequest {
|
||||||
return &TestHttpRequest{*httplib.Put(baseUrl + getPort() + path)}
|
return &TestHttpRequest{*httplib.Put(baseUrl + getPort() + path)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns test client in DELETE method
|
||||||
func Delete(path string) *TestHttpRequest {
|
func Delete(path string) *TestHttpRequest {
|
||||||
return &TestHttpRequest{*httplib.Delete(baseUrl + getPort() + path)}
|
return &TestHttpRequest{*httplib.Delete(baseUrl + getPort() + path)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns test client in HEAD method
|
||||||
func Head(path string) *TestHttpRequest {
|
func Head(path string) *TestHttpRequest {
|
||||||
return &TestHttpRequest{*httplib.Head(baseUrl + getPort() + path)}
|
return &TestHttpRequest{*httplib.Head(baseUrl + getPort() + path)}
|
||||||
}
|
}
|
||||||
|
@ -29,16 +29,13 @@ type pointerInfo struct {
|
|||||||
used []int
|
used []int
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// print the data in console
|
// print the data in console
|
||||||
//
|
|
||||||
func Display(data ...interface{}) {
|
func Display(data ...interface{}) {
|
||||||
display(true, data...)
|
display(true, data...)
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// return string
|
// return data print string
|
||||||
//
|
|
||||||
func GetDisplayString(data ...interface{}) string {
|
func GetDisplayString(data ...interface{}) string {
|
||||||
return display(false, data...)
|
return display(false, data...)
|
||||||
}
|
}
|
||||||
@ -67,9 +64,7 @@ func display(displayed bool, data ...interface{}) string {
|
|||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
// return data dump and format bytes
|
||||||
// return fomateinfo
|
|
||||||
//
|
|
||||||
func fomateinfo(headlen int, data ...interface{}) []byte {
|
func fomateinfo(headlen int, data ...interface{}) []byte {
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
|
|
||||||
@ -108,6 +103,7 @@ func fomateinfo(headlen int, data ...interface{}) []byte {
|
|||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check data is golang basic type
|
||||||
func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool {
|
func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool {
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
@ -158,6 +154,7 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dump value
|
||||||
func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) {
|
func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) {
|
||||||
var t = val.Kind()
|
var t = val.Kind()
|
||||||
|
|
||||||
@ -367,6 +364,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dump pointer value
|
||||||
func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
|
func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
|
||||||
var anyused = false
|
var anyused = false
|
||||||
var pointerNum = 0
|
var pointerNum = 0
|
||||||
@ -434,9 +432,7 @@ func printPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
// get stack bytes
|
||||||
// get stack info
|
|
||||||
//
|
|
||||||
func stack(skip int, indent string) []byte {
|
func stack(skip int, indent string) []byte {
|
||||||
var buf = new(bytes.Buffer)
|
var buf = new(bytes.Buffer)
|
||||||
|
|
||||||
@ -455,7 +451,7 @@ func stack(skip int, indent string) []byte {
|
|||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// function returns, if possible, the name of the function containing the PC.
|
// return the name of the function containing the PC if possible,
|
||||||
func function(pc uintptr) []byte {
|
func function(pc uintptr) []byte {
|
||||||
fn := runtime.FuncForPC(pc)
|
fn := runtime.FuncForPC(pc)
|
||||||
if fn == nil {
|
if fn == nil {
|
||||||
|
@ -13,12 +13,15 @@ package toolbox
|
|||||||
|
|
||||||
//AddHealthCheck("database",&DatabaseCheck{})
|
//AddHealthCheck("database",&DatabaseCheck{})
|
||||||
|
|
||||||
|
// health checker map
|
||||||
var AdminCheckList map[string]HealthChecker
|
var AdminCheckList map[string]HealthChecker
|
||||||
|
|
||||||
|
// health checker interface
|
||||||
type HealthChecker interface {
|
type HealthChecker interface {
|
||||||
Check() error
|
Check() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add health checker with name string
|
||||||
func AddHealthCheck(name string, hc HealthChecker) {
|
func AddHealthCheck(name string, hc HealthChecker) {
|
||||||
AdminCheckList[name] = hc
|
AdminCheckList[name] = hc
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ func init() {
|
|||||||
pid = os.Getpid()
|
pid = os.Getpid()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse input command string
|
||||||
func ProcessInput(input string, w io.Writer) {
|
func ProcessInput(input string, w io.Writer) {
|
||||||
switch input {
|
switch input {
|
||||||
case "lookup goroutine":
|
case "lookup goroutine":
|
||||||
@ -44,6 +45,7 @@ func ProcessInput(input string, w io.Writer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// record memory profile in pprof
|
||||||
func MemProf() {
|
func MemProf() {
|
||||||
if f, err := os.Create("mem-" + strconv.Itoa(pid) + ".memprof"); err != nil {
|
if f, err := os.Create("mem-" + strconv.Itoa(pid) + ".memprof"); err != nil {
|
||||||
log.Fatal("record memory profile failed: %v", err)
|
log.Fatal("record memory profile failed: %v", err)
|
||||||
@ -54,6 +56,7 @@ func MemProf() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// start cpu profile monitor
|
||||||
func StartCPUProfile() {
|
func StartCPUProfile() {
|
||||||
f, err := os.Create("cpu-" + strconv.Itoa(pid) + ".pprof")
|
f, err := os.Create("cpu-" + strconv.Itoa(pid) + ".pprof")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -62,10 +65,12 @@ func StartCPUProfile() {
|
|||||||
pprof.StartCPUProfile(f)
|
pprof.StartCPUProfile(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stop cpu profile monitor
|
||||||
func StopCPUProfile() {
|
func StopCPUProfile() {
|
||||||
pprof.StopCPUProfile()
|
pprof.StopCPUProfile()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// print gc information to io.Writer
|
||||||
func PrintGCSummary(w io.Writer) {
|
func PrintGCSummary(w io.Writer) {
|
||||||
memStats := &runtime.MemStats{}
|
memStats := &runtime.MemStats{}
|
||||||
runtime.ReadMemStats(memStats)
|
runtime.ReadMemStats(memStats)
|
||||||
@ -114,7 +119,7 @@ func avg(items []time.Duration) time.Duration {
|
|||||||
return time.Duration(int64(sum) / int64(len(items)))
|
return time.Duration(int64(sum) / int64(len(items)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// human readable format
|
// format bytes number friendly
|
||||||
func toH(bytes uint64) string {
|
func toH(bytes uint64) string {
|
||||||
switch {
|
switch {
|
||||||
case bytes < 1024:
|
case bytes < 1024:
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Statistics struct
|
||||||
type Statistics struct {
|
type Statistics struct {
|
||||||
RequestUrl string
|
RequestUrl string
|
||||||
RequestController string
|
RequestController string
|
||||||
@ -16,12 +17,15 @@ type Statistics struct {
|
|||||||
TotalTime time.Duration
|
TotalTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UrlMap contains several statistics struct to log different data
|
||||||
type UrlMap struct {
|
type UrlMap struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit
|
LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit
|
||||||
urlmap map[string]map[string]*Statistics
|
urlmap map[string]map[string]*Statistics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add statistics task.
|
||||||
|
// it needs request method, request url, request controller and statistics time duration
|
||||||
func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController string, requesttime time.Duration) {
|
func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController string, requesttime time.Duration) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
defer m.lock.Unlock()
|
defer m.lock.Unlock()
|
||||||
@ -65,6 +69,7 @@ func (m *UrlMap) AddStatistics(requestMethod, requestUrl, requestController stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// put url statistics result in io.Writer
|
||||||
func (m *UrlMap) GetMap(rw io.Writer) {
|
func (m *UrlMap) GetMap(rw io.Writer) {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
@ -78,6 +83,7 @@ func (m *UrlMap) GetMap(rw io.Writer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// global statistics data map
|
||||||
var StatisticsMap *UrlMap
|
var StatisticsMap *UrlMap
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -53,6 +53,7 @@ const (
|
|||||||
starBit = 1 << 63
|
starBit = 1 << 63
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// time taks schedule
|
||||||
type Schedule struct {
|
type Schedule struct {
|
||||||
Second uint64
|
Second uint64
|
||||||
Minute uint64
|
Minute uint64
|
||||||
@ -62,8 +63,10 @@ type Schedule struct {
|
|||||||
Week uint64
|
Week uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// task func type
|
||||||
type TaskFunc func() error
|
type TaskFunc func() error
|
||||||
|
|
||||||
|
// task interface
|
||||||
type Tasker interface {
|
type Tasker interface {
|
||||||
GetStatus() string
|
GetStatus() string
|
||||||
Run() error
|
Run() error
|
||||||
@ -73,21 +76,24 @@ type Tasker interface {
|
|||||||
GetPrev() time.Time
|
GetPrev() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// task error
|
||||||
type taskerr struct {
|
type taskerr struct {
|
||||||
t time.Time
|
t time.Time
|
||||||
errinfo string
|
errinfo string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// task struct
|
||||||
type Task struct {
|
type Task struct {
|
||||||
Taskname string
|
Taskname string
|
||||||
Spec *Schedule
|
Spec *Schedule
|
||||||
DoFunc TaskFunc
|
DoFunc TaskFunc
|
||||||
Prev time.Time
|
Prev time.Time
|
||||||
Next time.Time
|
Next time.Time
|
||||||
Errlist []*taskerr //errtime:errinfo
|
Errlist []*taskerr // like errtime:errinfo
|
||||||
ErrLimit int //max length for the errlist 0 stand for there' no limit
|
ErrLimit int // max length for the errlist, 0 stand for no limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add new task with name, time and func
|
||||||
func NewTask(tname string, spec string, f TaskFunc) *Task {
|
func NewTask(tname string, spec string, f TaskFunc) *Task {
|
||||||
|
|
||||||
task := &Task{
|
task := &Task{
|
||||||
@ -99,6 +105,7 @@ func NewTask(tname string, spec string, f TaskFunc) *Task {
|
|||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get current task status
|
||||||
func (tk *Task) GetStatus() string {
|
func (tk *Task) GetStatus() string {
|
||||||
var str string
|
var str string
|
||||||
for _, v := range tk.Errlist {
|
for _, v := range tk.Errlist {
|
||||||
@ -107,6 +114,7 @@ func (tk *Task) GetStatus() string {
|
|||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run task
|
||||||
func (tk *Task) Run() error {
|
func (tk *Task) Run() error {
|
||||||
err := tk.DoFunc()
|
err := tk.DoFunc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -117,53 +125,58 @@ func (tk *Task) Run() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set next time for this task
|
||||||
func (tk *Task) SetNext(now time.Time) {
|
func (tk *Task) SetNext(now time.Time) {
|
||||||
tk.Next = tk.Spec.Next(now)
|
tk.Next = tk.Spec.Next(now)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get the next call time of this task
|
||||||
func (tk *Task) GetNext() time.Time {
|
func (tk *Task) GetNext() time.Time {
|
||||||
return tk.Next
|
return tk.Next
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set prev time of this task
|
||||||
func (tk *Task) SetPrev(now time.Time) {
|
func (tk *Task) SetPrev(now time.Time) {
|
||||||
tk.Prev = now
|
tk.Prev = now
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get prev time of this task
|
||||||
func (tk *Task) GetPrev() time.Time {
|
func (tk *Task) GetPrev() time.Time {
|
||||||
return tk.Prev
|
return tk.Prev
|
||||||
}
|
}
|
||||||
|
|
||||||
//前6个字段分别表示:
|
// six columns mean:
|
||||||
// 秒钟:0-59
|
// second:0-59
|
||||||
// 分钟:0-59
|
// minute:0-59
|
||||||
// 小时:1-23
|
// hour:1-23
|
||||||
// 日期:1-31
|
// day:1-31
|
||||||
// 月份:1-12
|
// month:1-12
|
||||||
// 星期:0-6(0表示周日)
|
// week:0-6(0 means Sunday)
|
||||||
|
|
||||||
//还可以用一些特殊符号:
|
// some signals:
|
||||||
// *: 表示任何时刻
|
// *: any time
|
||||||
// ,: 表示分割,如第三段里:2,4,表示2点和4点执行
|
// ,: separate signal
|
||||||
// -:表示一个段,如第三端里: 1-5,就表示1到5点
|
// -:duration
|
||||||
// /n : 表示每个n的单位执行一次,如第三段里,*/1, 就表示每隔1个小时执行一次命令。也可以写成1-23/1.
|
// /n : do as n times of time duration
|
||||||
/////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////
|
||||||
// 0/30 * * * * * 每30秒 执行
|
// 0/30 * * * * * every 30s
|
||||||
// 0 43 21 * * * 21:43 执行
|
// 0 43 21 * * * 21:43
|
||||||
// 0 15 05 * * * 05:15 执行
|
// 0 15 05 * * * 05:15
|
||||||
// 0 0 17 * * * 17:00 执行
|
// 0 0 17 * * * 17:00
|
||||||
// 0 0 17 * * 1 每周一的 17:00 执行
|
// 0 0 17 * * 1 17:00 in every Monday
|
||||||
// 0 0,10 17 * * 0,2,3 每周日,周二,周三的 17:00和 17:10 执行
|
// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday
|
||||||
// 0 0-10 17 1 * * 毎月1日从 17:00到7:10 毎隔1分钟 执行
|
// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month
|
||||||
// 0 0 0 1,15 * 1 毎月1日和 15日和 一日的 0:00 执行
|
// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month
|
||||||
// 0 42 4 1 * * 毎月1日的 4:42分 执行
|
// 0 42 4 1 * * 4:42 on the 1st day of month
|
||||||
// 0 0 21 * * 1-6 周一到周六 21:00 执行
|
// 0 0 21 * * 1-6 21:00 from Monday to Saturday
|
||||||
// 0 0,10,20,30,40,50 * * * * 每隔10分 执行
|
// 0 0,10,20,30,40,50 * * * * every 10 min duration
|
||||||
// 0 */10 * * * * 每隔10分 执行
|
// 0 */10 * * * * every 10 min duration
|
||||||
// 0 * 1 * * * 从1:0到1:59 每隔1分钟 执行
|
// 0 * 1 * * * 1:00 to 1:59 in 1 min duration each time
|
||||||
// 0 0 1 * * * 1:00 执行
|
// 0 0 1 * * * 1:00
|
||||||
// 0 0 */1 * * * 毎时0分 每隔1小时 执行
|
// 0 0 */1 * * * 0 min of hour in 1 hour duration
|
||||||
// 0 0 * * * * 毎时0分 每隔1小时 执行
|
// 0 0 * * * * 0 min of hour in 1 hour duration
|
||||||
// 0 2 8-20/3 * * * 8:02,11:02,14:02,17:02,20:02 执行
|
// 0 2 8-20/3 * * * 8:02, 11:02, 14:02, 17:02, 20:02
|
||||||
// 0 30 5 1,15 * * 1日 和 15日的 5:30 执行
|
// 0 30 5 1,15 * * 5:30 on the 1st day and 15th day of month
|
||||||
func (t *Task) SetCron(spec string) {
|
func (t *Task) SetCron(spec string) {
|
||||||
t.Spec = t.parse(spec)
|
t.Spec = t.parse(spec)
|
||||||
}
|
}
|
||||||
@ -252,6 +265,7 @@ func (t *Task) parseSpec(spec string) *Schedule {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set schedule to next time
|
||||||
func (s *Schedule) Next(t time.Time) time.Time {
|
func (s *Schedule) Next(t time.Time) time.Time {
|
||||||
|
|
||||||
// Start at the earliest possible time (the upcoming second).
|
// Start at the earliest possible time (the upcoming second).
|
||||||
@ -349,6 +363,7 @@ func dayMatches(s *Schedule, t time.Time) bool {
|
|||||||
return domMatch || dowMatch
|
return domMatch || dowMatch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// start all tasks
|
||||||
func StartTask() {
|
func StartTask() {
|
||||||
go run()
|
go run()
|
||||||
}
|
}
|
||||||
@ -388,10 +403,12 @@ func run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// start all tasks
|
||||||
func StopTask() {
|
func StopTask() {
|
||||||
stop <- true
|
stop <- true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add task with name
|
||||||
func AddTask(taskname string, t Tasker) {
|
func AddTask(taskname string, t Tasker) {
|
||||||
AdminTaskList[taskname] = t
|
AdminTaskList[taskname] = t
|
||||||
}
|
}
|
||||||
@ -402,6 +419,7 @@ type MapSorter struct {
|
|||||||
Vals []Tasker
|
Vals []Tasker
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create new tasker map
|
||||||
func NewMapSorter(m map[string]Tasker) *MapSorter {
|
func NewMapSorter(m map[string]Tasker) *MapSorter {
|
||||||
ms := &MapSorter{
|
ms := &MapSorter{
|
||||||
Keys: make([]string, 0, len(m)),
|
Keys: make([]string, 0, len(m)),
|
||||||
@ -414,6 +432,7 @@ func NewMapSorter(m map[string]Tasker) *MapSorter {
|
|||||||
return ms
|
return ms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sort tasker map
|
||||||
func (ms *MapSorter) Sort() {
|
func (ms *MapSorter) Sort() {
|
||||||
sort.Sort(ms)
|
sort.Sort(ms)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// get function name
|
||||||
func GetFuncName(i interface{}) string {
|
func GetFuncName(i interface{}) string {
|
||||||
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
|
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
|
||||||
}
|
}
|
||||||
|
45
utils/captcha/README.md
Normal file
45
utils/captcha/README.md
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Captcha
|
||||||
|
|
||||||
|
an example for use captcha
|
||||||
|
|
||||||
|
```
|
||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/cache"
|
||||||
|
"github.com/astaxie/beego/utils/captcha"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cpt *captcha.Captcha
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// use beego cache system store the captcha data
|
||||||
|
store := cache.NewMemoryCache()
|
||||||
|
cpt = captcha.NewWithFilter("/captcha/", store)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MainController struct {
|
||||||
|
beego.Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *MainController) Get() {
|
||||||
|
this.TplNames = "index.tpl"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *MainController) Post() {
|
||||||
|
this.TplNames = "index.tpl"
|
||||||
|
|
||||||
|
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
template usage
|
||||||
|
|
||||||
|
```
|
||||||
|
{{.Success}}
|
||||||
|
<form action="/" method="post">
|
||||||
|
{{create_captcha}}
|
||||||
|
<input name="captcha" type="text">
|
||||||
|
</form>
|
||||||
|
```
|
251
utils/captcha/captcha.go
Normal file
251
utils/captcha/captcha.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
// an example for use captcha
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// package controllers
|
||||||
|
//
|
||||||
|
// import (
|
||||||
|
// "github.com/astaxie/beego"
|
||||||
|
// "github.com/astaxie/beego/cache"
|
||||||
|
// "github.com/astaxie/beego/utils/captcha"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// var cpt *captcha.Captcha
|
||||||
|
//
|
||||||
|
// func init() {
|
||||||
|
// // use beego cache system store the captcha data
|
||||||
|
// store := cache.NewMemoryCache()
|
||||||
|
// cpt = captcha.NewWithFilter("/captcha/", store)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// type MainController struct {
|
||||||
|
// beego.Controller
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func (this *MainController) Get() {
|
||||||
|
// this.TplNames = "index.tpl"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func (this *MainController) Post() {
|
||||||
|
// this.TplNames = "index.tpl"
|
||||||
|
//
|
||||||
|
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
|
||||||
|
// }
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// template usage
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// {{.Success}}
|
||||||
|
// <form action="/" method="post">
|
||||||
|
// {{create_captcha}}
|
||||||
|
// <input name="captcha" type="text">
|
||||||
|
// </form>
|
||||||
|
// ```
|
||||||
|
package captcha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/cache"
|
||||||
|
"github.com/astaxie/beego/context"
|
||||||
|
"github.com/astaxie/beego/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// default captcha attributes
|
||||||
|
challengeNums = 6
|
||||||
|
expiration = 600
|
||||||
|
fieldIdName = "captcha_id"
|
||||||
|
fieldCaptchaName = "captcha"
|
||||||
|
cachePrefix = "captcha_"
|
||||||
|
urlPrefix = "/captcha/"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Captcha struct
|
||||||
|
type Captcha struct {
|
||||||
|
// beego cache store
|
||||||
|
store cache.Cache
|
||||||
|
|
||||||
|
// url prefix for captcha image
|
||||||
|
urlPrefix string
|
||||||
|
|
||||||
|
// specify captcha id input field name
|
||||||
|
FieldIdName string
|
||||||
|
// specify captcha result input field name
|
||||||
|
FieldCaptchaName string
|
||||||
|
|
||||||
|
// captcha image width and height
|
||||||
|
StdWidth int
|
||||||
|
StdHeight int
|
||||||
|
|
||||||
|
// captcha chars nums
|
||||||
|
ChallengeNums int
|
||||||
|
|
||||||
|
// captcha expiration seconds
|
||||||
|
Expiration int64
|
||||||
|
|
||||||
|
// cache key prefix
|
||||||
|
CachePrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate key string
|
||||||
|
func (c *Captcha) key(id string) string {
|
||||||
|
return c.CachePrefix + id
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate rand chars with default chars
|
||||||
|
func (c *Captcha) genRandChars() []byte {
|
||||||
|
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beego filter handler for serve captcha image
|
||||||
|
func (c *Captcha) Handler(ctx *context.Context) {
|
||||||
|
var chars []byte
|
||||||
|
|
||||||
|
id := path.Base(ctx.Request.RequestURI)
|
||||||
|
if i := strings.Index(id, "."); i != -1 {
|
||||||
|
id = id[:i]
|
||||||
|
}
|
||||||
|
|
||||||
|
key := c.key(id)
|
||||||
|
|
||||||
|
if v, ok := c.store.Get(key).([]byte); ok {
|
||||||
|
chars = v
|
||||||
|
} else {
|
||||||
|
ctx.Output.SetStatus(404)
|
||||||
|
ctx.WriteString("captcha not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// reload captcha
|
||||||
|
if len(ctx.Input.Query("reload")) > 0 {
|
||||||
|
chars = c.genRandChars()
|
||||||
|
if err := c.store.Put(key, chars, c.Expiration); err != nil {
|
||||||
|
ctx.Output.SetStatus(500)
|
||||||
|
ctx.WriteString("captcha reload error")
|
||||||
|
beego.Error("Reload Create Captcha Error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
img := NewImage(chars, c.StdWidth, c.StdHeight)
|
||||||
|
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
|
||||||
|
beego.Error("Write Captcha Image Error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tempalte func for output html
|
||||||
|
func (c *Captcha) CreateCaptchaHtml() template.HTML {
|
||||||
|
value, err := c.CreateCaptcha()
|
||||||
|
if err != nil {
|
||||||
|
beego.Error("Create Captcha Error:", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// create html
|
||||||
|
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
|
||||||
|
`<a class="captcha" href="javascript:">`+
|
||||||
|
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
|
||||||
|
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new captcha id
|
||||||
|
func (c *Captcha) CreateCaptcha() (string, error) {
|
||||||
|
// generate captcha id
|
||||||
|
id := string(utils.RandomCreateBytes(15))
|
||||||
|
|
||||||
|
// get the captcha chars
|
||||||
|
chars := c.genRandChars()
|
||||||
|
|
||||||
|
// save to store
|
||||||
|
if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify from a request
|
||||||
|
func (c *Captcha) VerifyReq(req *http.Request) bool {
|
||||||
|
req.ParseForm()
|
||||||
|
return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// direct verify id and challenge string
|
||||||
|
func (c *Captcha) Verify(id string, challenge string) (success bool) {
|
||||||
|
if len(challenge) == 0 || len(id) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var chars []byte
|
||||||
|
|
||||||
|
key := c.key(id)
|
||||||
|
|
||||||
|
if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) {
|
||||||
|
chars = v
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// finally remove it
|
||||||
|
c.store.Delete(key)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// verify challenge
|
||||||
|
for i, c := range chars {
|
||||||
|
if c != challenge[i]-48 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new captcha.Captcha
|
||||||
|
func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
|
||||||
|
cpt := &Captcha{}
|
||||||
|
cpt.store = store
|
||||||
|
cpt.FieldIdName = fieldIdName
|
||||||
|
cpt.FieldCaptchaName = fieldCaptchaName
|
||||||
|
cpt.ChallengeNums = challengeNums
|
||||||
|
cpt.Expiration = expiration
|
||||||
|
cpt.CachePrefix = cachePrefix
|
||||||
|
cpt.StdWidth = stdWidth
|
||||||
|
cpt.StdHeight = stdHeight
|
||||||
|
|
||||||
|
if len(urlPrefix) == 0 {
|
||||||
|
urlPrefix = urlPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
if urlPrefix[len(urlPrefix)-1] != '/' {
|
||||||
|
urlPrefix += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
cpt.urlPrefix = urlPrefix
|
||||||
|
|
||||||
|
return cpt
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new captcha.Captcha and auto AddFilter for serve captacha image
|
||||||
|
// and add a tempalte func for output html
|
||||||
|
func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
|
||||||
|
cpt := NewCaptcha(urlPrefix, store)
|
||||||
|
|
||||||
|
// create filter for serve captcha image
|
||||||
|
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
|
||||||
|
|
||||||
|
// add to template func map
|
||||||
|
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
|
||||||
|
|
||||||
|
return cpt
|
||||||
|
}
|
484
utils/captcha/image.go
Normal file
484
utils/captcha/image.go
Normal file
@ -0,0 +1,484 @@
|
|||||||
|
// modifiy and integrated to Beego from https://github.com/dchest/captcha
|
||||||
|
package captcha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fontWidth = 11
|
||||||
|
fontHeight = 18
|
||||||
|
blackChar = 1
|
||||||
|
|
||||||
|
// Standard width and height of a captcha image.
|
||||||
|
stdWidth = 240
|
||||||
|
stdHeight = 80
|
||||||
|
// Maximum absolute skew factor of a single digit.
|
||||||
|
maxSkew = 0.7
|
||||||
|
// Number of background circles.
|
||||||
|
circleCount = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
var font = [][]byte{
|
||||||
|
{ // 0
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 1
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
},
|
||||||
|
{ // 2
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
},
|
||||||
|
{ // 3
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 4
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 5
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 6
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 7
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 8
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
},
|
||||||
|
{ // 9
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
|
||||||
|
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
|
||||||
|
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
|
||||||
|
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type Image struct {
|
||||||
|
*image.Paletted
|
||||||
|
numWidth int
|
||||||
|
numHeight int
|
||||||
|
dotSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
var prng = &siprng{}
|
||||||
|
|
||||||
|
// randIntn returns a pseudorandom non-negative int in range [0, n).
|
||||||
|
func randIntn(n int) int {
|
||||||
|
return prng.Intn(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randInt returns a pseudorandom int in range [from, to].
|
||||||
|
func randInt(from, to int) int {
|
||||||
|
return prng.Intn(to+1-from) + from
|
||||||
|
}
|
||||||
|
|
||||||
|
// randFloat returns a pseudorandom float64 in range [from, to].
|
||||||
|
func randFloat(from, to float64) float64 {
|
||||||
|
return (to-from)*prng.Float64() + from
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomPalette() color.Palette {
|
||||||
|
p := make([]color.Color, circleCount+1)
|
||||||
|
// Transparent color.
|
||||||
|
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
|
||||||
|
// Primary color.
|
||||||
|
prim := color.RGBA{
|
||||||
|
uint8(randIntn(129)),
|
||||||
|
uint8(randIntn(129)),
|
||||||
|
uint8(randIntn(129)),
|
||||||
|
0xFF,
|
||||||
|
}
|
||||||
|
p[1] = prim
|
||||||
|
// Circle colors.
|
||||||
|
for i := 2; i <= circleCount; i++ {
|
||||||
|
p[i] = randomBrightness(prim, 255)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewImage returns a new captcha image of the given width and height with the
|
||||||
|
// given digits, where each digit must be in range 0-9.
|
||||||
|
func NewImage(digits []byte, width, height int) *Image {
|
||||||
|
m := new(Image)
|
||||||
|
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
|
||||||
|
m.calculateSizes(width, height, len(digits))
|
||||||
|
// Randomly position captcha inside the image.
|
||||||
|
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
|
||||||
|
maxy := height - m.numHeight - m.dotSize*2
|
||||||
|
var border int
|
||||||
|
if width > height {
|
||||||
|
border = height / 5
|
||||||
|
} else {
|
||||||
|
border = width / 5
|
||||||
|
}
|
||||||
|
x := randInt(border, maxx-border)
|
||||||
|
y := randInt(border, maxy-border)
|
||||||
|
// Draw digits.
|
||||||
|
for _, n := range digits {
|
||||||
|
m.drawDigit(font[n], x, y)
|
||||||
|
x += m.numWidth + m.dotSize
|
||||||
|
}
|
||||||
|
// Draw strike-through line.
|
||||||
|
m.strikeThrough()
|
||||||
|
// Apply wave distortion.
|
||||||
|
m.distort(randFloat(5, 10), randFloat(100, 200))
|
||||||
|
// Fill image with random circles.
|
||||||
|
m.fillWithCircles(circleCount, m.dotSize)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodedPNG encodes an image to PNG and returns
|
||||||
|
// the result as a byte slice.
|
||||||
|
func (m *Image) encodedPNG() []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := png.Encode(&buf, m.Paletted); err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes captcha image in PNG format into the given writer.
|
||||||
|
func (m *Image) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
n, err := w.Write(m.encodedPNG())
|
||||||
|
return int64(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) calculateSizes(width, height, ncount int) {
|
||||||
|
// Goal: fit all digits inside the image.
|
||||||
|
var border int
|
||||||
|
if width > height {
|
||||||
|
border = height / 4
|
||||||
|
} else {
|
||||||
|
border = width / 4
|
||||||
|
}
|
||||||
|
// Convert everything to floats for calculations.
|
||||||
|
w := float64(width - border*2)
|
||||||
|
h := float64(height - border*2)
|
||||||
|
// fw takes into account 1-dot spacing between digits.
|
||||||
|
fw := float64(fontWidth + 1)
|
||||||
|
fh := float64(fontHeight)
|
||||||
|
nc := float64(ncount)
|
||||||
|
// Calculate the width of a single digit taking into account only the
|
||||||
|
// width of the image.
|
||||||
|
nw := w / nc
|
||||||
|
// Calculate the height of a digit from this width.
|
||||||
|
nh := nw * fh / fw
|
||||||
|
// Digit too high?
|
||||||
|
if nh > h {
|
||||||
|
// Fit digits based on height.
|
||||||
|
nh = h
|
||||||
|
nw = fw / fh * nh
|
||||||
|
}
|
||||||
|
// Calculate dot size.
|
||||||
|
m.dotSize = int(nh / fh)
|
||||||
|
// Save everything, making the actual width smaller by 1 dot to account
|
||||||
|
// for spacing between digits.
|
||||||
|
m.numWidth = int(nw) - m.dotSize
|
||||||
|
m.numHeight = int(nh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
|
||||||
|
for x := fromX; x <= toX; x++ {
|
||||||
|
m.SetColorIndex(x, y, colorIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
|
||||||
|
f := 1 - radius
|
||||||
|
dfx := 1
|
||||||
|
dfy := -2 * radius
|
||||||
|
xo := 0
|
||||||
|
yo := radius
|
||||||
|
|
||||||
|
m.SetColorIndex(x, y+radius, colorIdx)
|
||||||
|
m.SetColorIndex(x, y-radius, colorIdx)
|
||||||
|
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
|
||||||
|
|
||||||
|
for xo < yo {
|
||||||
|
if f >= 0 {
|
||||||
|
yo--
|
||||||
|
dfy += 2
|
||||||
|
f += dfy
|
||||||
|
}
|
||||||
|
xo++
|
||||||
|
dfx += 2
|
||||||
|
f += dfx
|
||||||
|
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
|
||||||
|
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
|
||||||
|
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
|
||||||
|
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) fillWithCircles(n, maxradius int) {
|
||||||
|
maxx := m.Bounds().Max.X
|
||||||
|
maxy := m.Bounds().Max.Y
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
colorIdx := uint8(randInt(1, circleCount-1))
|
||||||
|
r := randInt(1, maxradius)
|
||||||
|
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) strikeThrough() {
|
||||||
|
maxx := m.Bounds().Max.X
|
||||||
|
maxy := m.Bounds().Max.Y
|
||||||
|
y := randInt(maxy/3, maxy-maxy/3)
|
||||||
|
amplitude := randFloat(5, 20)
|
||||||
|
period := randFloat(80, 180)
|
||||||
|
dx := 2.0 * math.Pi / period
|
||||||
|
for x := 0; x < maxx; x++ {
|
||||||
|
xo := amplitude * math.Cos(float64(y)*dx)
|
||||||
|
yo := amplitude * math.Sin(float64(x)*dx)
|
||||||
|
for yn := 0; yn < m.dotSize; yn++ {
|
||||||
|
r := randInt(0, m.dotSize)
|
||||||
|
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) drawDigit(digit []byte, x, y int) {
|
||||||
|
skf := randFloat(-maxSkew, maxSkew)
|
||||||
|
xs := float64(x)
|
||||||
|
r := m.dotSize / 2
|
||||||
|
y += randInt(-r, r)
|
||||||
|
for yo := 0; yo < fontHeight; yo++ {
|
||||||
|
for xo := 0; xo < fontWidth; xo++ {
|
||||||
|
if digit[yo*fontWidth+xo] != blackChar {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
|
||||||
|
}
|
||||||
|
xs += skf
|
||||||
|
x = int(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Image) distort(amplude float64, period float64) {
|
||||||
|
w := m.Bounds().Max.X
|
||||||
|
h := m.Bounds().Max.Y
|
||||||
|
|
||||||
|
oldm := m.Paletted
|
||||||
|
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
|
||||||
|
|
||||||
|
dx := 2.0 * math.Pi / period
|
||||||
|
for x := 0; x < w; x++ {
|
||||||
|
for y := 0; y < h; y++ {
|
||||||
|
xo := amplude * math.Sin(float64(y)*dx)
|
||||||
|
yo := amplude * math.Cos(float64(x)*dx)
|
||||||
|
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.Paletted = newm
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
|
||||||
|
minc := min3(c.R, c.G, c.B)
|
||||||
|
maxc := max3(c.R, c.G, c.B)
|
||||||
|
if maxc > max {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
n := randIntn(int(max-maxc)) - int(minc)
|
||||||
|
return color.RGBA{
|
||||||
|
uint8(int(c.R) + n),
|
||||||
|
uint8(int(c.G) + n),
|
||||||
|
uint8(int(c.B) + n),
|
||||||
|
uint8(c.A),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func min3(x, y, z uint8) (m uint8) {
|
||||||
|
m = x
|
||||||
|
if y < m {
|
||||||
|
m = y
|
||||||
|
}
|
||||||
|
if z < m {
|
||||||
|
m = z
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func max3(x, y, z uint8) (m uint8) {
|
||||||
|
m = x
|
||||||
|
if y > m {
|
||||||
|
m = y
|
||||||
|
}
|
||||||
|
if z > m {
|
||||||
|
m = z
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
38
utils/captcha/image_test.go
Normal file
38
utils/captcha/image_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package captcha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type byteCounter struct {
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc *byteCounter) Write(b []byte) (int, error) {
|
||||||
|
bc.n += int64(len(b))
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNewImage(b *testing.B) {
|
||||||
|
b.StopTimer()
|
||||||
|
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
|
||||||
|
b.StartTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
NewImage(d, stdWidth, stdHeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkImageWriteTo(b *testing.B) {
|
||||||
|
b.StopTimer()
|
||||||
|
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
|
||||||
|
b.StartTimer()
|
||||||
|
counter := &byteCounter{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
img := NewImage(d, stdWidth, stdHeight)
|
||||||
|
img.WriteTo(counter)
|
||||||
|
b.SetBytes(counter.n)
|
||||||
|
counter.n = 0
|
||||||
|
}
|
||||||
|
}
|
264
utils/captcha/siprng.go
Normal file
264
utils/captcha/siprng.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
// modifiy and integrated to Beego from https://github.com/dchest/captcha
|
||||||
|
package captcha
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// siprng is PRNG based on SipHash-2-4.
|
||||||
|
type siprng struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
k0, k1, ctr uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// siphash implements SipHash-2-4, accepting a uint64 as a message.
|
||||||
|
func siphash(k0, k1, m uint64) uint64 {
|
||||||
|
// Initialization.
|
||||||
|
v0 := k0 ^ 0x736f6d6570736575
|
||||||
|
v1 := k1 ^ 0x646f72616e646f6d
|
||||||
|
v2 := k0 ^ 0x6c7967656e657261
|
||||||
|
v3 := k1 ^ 0x7465646279746573
|
||||||
|
t := uint64(8) << 56
|
||||||
|
|
||||||
|
// Compression.
|
||||||
|
v3 ^= m
|
||||||
|
|
||||||
|
// Round 1.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
// Round 2.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
v0 ^= m
|
||||||
|
|
||||||
|
// Compress last block.
|
||||||
|
v3 ^= t
|
||||||
|
|
||||||
|
// Round 1.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
// Round 2.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
v0 ^= t
|
||||||
|
|
||||||
|
// Finalization.
|
||||||
|
v2 ^= 0xff
|
||||||
|
|
||||||
|
// Round 1.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
// Round 2.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
// Round 3.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
// Round 4.
|
||||||
|
v0 += v1
|
||||||
|
v1 = v1<<13 | v1>>(64-13)
|
||||||
|
v1 ^= v0
|
||||||
|
v0 = v0<<32 | v0>>(64-32)
|
||||||
|
|
||||||
|
v2 += v3
|
||||||
|
v3 = v3<<16 | v3>>(64-16)
|
||||||
|
v3 ^= v2
|
||||||
|
|
||||||
|
v0 += v3
|
||||||
|
v3 = v3<<21 | v3>>(64-21)
|
||||||
|
v3 ^= v0
|
||||||
|
|
||||||
|
v2 += v1
|
||||||
|
v1 = v1<<17 | v1>>(64-17)
|
||||||
|
v1 ^= v2
|
||||||
|
v2 = v2<<32 | v2>>(64-32)
|
||||||
|
|
||||||
|
return v0 ^ v1 ^ v2 ^ v3
|
||||||
|
}
|
||||||
|
|
||||||
|
// rekey sets a new PRNG key, which is read from crypto/rand.
|
||||||
|
func (p *siprng) rekey() {
|
||||||
|
var k [16]byte
|
||||||
|
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
p.k0 = binary.LittleEndian.Uint64(k[0:8])
|
||||||
|
p.k1 = binary.LittleEndian.Uint64(k[8:16])
|
||||||
|
p.ctr = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uint64 returns a new pseudorandom uint64.
|
||||||
|
// It rekeys PRNG on the first call and every 64 MB of generated data.
|
||||||
|
func (p *siprng) Uint64() uint64 {
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.ctr == 0 || p.ctr > 8*1024*1024 {
|
||||||
|
p.rekey()
|
||||||
|
}
|
||||||
|
v := siphash(p.k0, p.k1, p.ctr)
|
||||||
|
p.ctr++
|
||||||
|
p.mu.Unlock()
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Int63() int64 {
|
||||||
|
return int64(p.Uint64() & 0x7fffffffffffffff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Uint32() uint32 {
|
||||||
|
return uint32(p.Uint64())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Int31() int32 {
|
||||||
|
return int32(p.Uint32() & 0x7fffffff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Intn(n int) int {
|
||||||
|
if n <= 0 {
|
||||||
|
panic("invalid argument to Intn")
|
||||||
|
}
|
||||||
|
if n <= 1<<31-1 {
|
||||||
|
return int(p.Int31n(int32(n)))
|
||||||
|
}
|
||||||
|
return int(p.Int63n(int64(n)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Int63n(n int64) int64 {
|
||||||
|
if n <= 0 {
|
||||||
|
panic("invalid argument to Int63n")
|
||||||
|
}
|
||||||
|
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
|
||||||
|
v := p.Int63()
|
||||||
|
for v > max {
|
||||||
|
v = p.Int63()
|
||||||
|
}
|
||||||
|
return v % n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Int31n(n int32) int32 {
|
||||||
|
if n <= 0 {
|
||||||
|
panic("invalid argument to Int31n")
|
||||||
|
}
|
||||||
|
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
|
||||||
|
v := p.Int31()
|
||||||
|
for v > max {
|
||||||
|
v = p.Int31()
|
||||||
|
}
|
||||||
|
return v % n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }
|
19
utils/captcha/siprng_test.go
Normal file
19
utils/captcha/siprng_test.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package captcha
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSiphash(t *testing.T) {
|
||||||
|
good := uint64(0xe849e8bb6ffe2567)
|
||||||
|
cur := siphash(0, 0, 0)
|
||||||
|
if cur != good {
|
||||||
|
t.Fatalf("siphash: expected %x, got %x", good, cur)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSiprng(b *testing.B) {
|
||||||
|
b.SetBytes(8)
|
||||||
|
p := &siprng{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
p.Uint64()
|
||||||
|
}
|
||||||
|
}
|
@ -9,11 +9,13 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SelfPath gets compiled executable file absolute path
|
||||||
func SelfPath() string {
|
func SelfPath() string {
|
||||||
path, _ := filepath.Abs(os.Args[0])
|
path, _ := filepath.Abs(os.Args[0])
|
||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelfDir gets compiled executable file directory
|
||||||
func SelfDir() string {
|
func SelfDir() string {
|
||||||
return filepath.Dir(SelfPath())
|
return filepath.Dir(SelfPath())
|
||||||
}
|
}
|
||||||
@ -28,8 +30,8 @@ func FileExists(name string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// search a file in paths.
|
// Search a file in paths.
|
||||||
// this is offen used in search config file in /etc ~/
|
// this is often used in search config file in /etc ~/
|
||||||
func SearchFile(filename string, paths ...string) (fullpath string, err error) {
|
func SearchFile(filename string, paths ...string) (fullpath string, err error) {
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
if fullpath = filepath.Join(path, filename); FileExists(fullpath) {
|
if fullpath = filepath.Join(path, filename); FileExists(fullpath) {
|
||||||
|
301
utils/mail.go
Normal file
301
utils/mail.go
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/mail"
|
||||||
|
"net/smtp"
|
||||||
|
"net/textproto"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxLineLength = 76
|
||||||
|
)
|
||||||
|
|
||||||
|
// Email is the type used for email messages
|
||||||
|
type Email struct {
|
||||||
|
Auth smtp.Auth
|
||||||
|
Identity string `json:"identity"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
From string `json:"from"`
|
||||||
|
To []string
|
||||||
|
Bcc []string
|
||||||
|
Cc []string
|
||||||
|
Subject string
|
||||||
|
Text string // Plaintext message (optional)
|
||||||
|
HTML string // Html message (optional)
|
||||||
|
Headers textproto.MIMEHeader
|
||||||
|
Attachments []*Attachment
|
||||||
|
ReadReceipt []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attachment is a struct representing an email attachment.
|
||||||
|
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
|
||||||
|
type Attachment struct {
|
||||||
|
Filename string
|
||||||
|
Header textproto.MIMEHeader
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEMail create new Email struct with config json.
|
||||||
|
// config json is followed from Email struct fields.
|
||||||
|
func NewEMail(config string) *Email {
|
||||||
|
e := new(Email)
|
||||||
|
e.Headers = textproto.MIMEHeader{}
|
||||||
|
err := json.Unmarshal([]byte(config), e)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if e.From == "" {
|
||||||
|
e.From = e.Username
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make all send information to byte
|
||||||
|
func (e *Email) Bytes() ([]byte, error) {
|
||||||
|
buff := &bytes.Buffer{}
|
||||||
|
w := multipart.NewWriter(buff)
|
||||||
|
// Set the appropriate headers (overwriting any conflicts)
|
||||||
|
// Leave out Bcc (only included in envelope headers)
|
||||||
|
e.Headers.Set("To", strings.Join(e.To, ","))
|
||||||
|
if e.Cc != nil {
|
||||||
|
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
|
||||||
|
}
|
||||||
|
e.Headers.Set("From", e.From)
|
||||||
|
e.Headers.Set("Subject", e.Subject)
|
||||||
|
if len(e.ReadReceipt) != 0 {
|
||||||
|
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
|
||||||
|
}
|
||||||
|
e.Headers.Set("MIME-Version", "1.0")
|
||||||
|
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
|
||||||
|
|
||||||
|
// Write the envelope headers (including any custom headers)
|
||||||
|
if err := headerToBytes(buff, e.Headers); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to render message headers: %s", err)
|
||||||
|
}
|
||||||
|
// Start the multipart/mixed part
|
||||||
|
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
|
||||||
|
header := textproto.MIMEHeader{}
|
||||||
|
// Check to see if there is a Text or HTML field
|
||||||
|
if e.Text != "" || e.HTML != "" {
|
||||||
|
subWriter := multipart.NewWriter(buff)
|
||||||
|
// Create the multipart alternative part
|
||||||
|
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
|
||||||
|
// Write the header
|
||||||
|
if err := headerToBytes(buff, header); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
|
||||||
|
}
|
||||||
|
// Create the body sections
|
||||||
|
if e.Text != "" {
|
||||||
|
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
|
||||||
|
header.Set("Content-Transfer-Encoding", "quoted-printable")
|
||||||
|
if _, err := subWriter.CreatePart(header); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Write the text
|
||||||
|
if err := quotePrintEncode(buff, e.Text); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e.HTML != "" {
|
||||||
|
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
|
||||||
|
header.Set("Content-Transfer-Encoding", "quoted-printable")
|
||||||
|
if _, err := subWriter.CreatePart(header); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Write the text
|
||||||
|
if err := quotePrintEncode(buff, e.HTML); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := subWriter.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create attachment part, if necessary
|
||||||
|
for _, a := range e.Attachments {
|
||||||
|
ap, err := w.CreatePart(a.Header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Write the base64Wrapped content to the part
|
||||||
|
base64Wrap(ap, a.Content)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buff.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add attach file to the send mail
|
||||||
|
func (e *Email) AttachFile(filename string) (a *Attachment, err error) {
|
||||||
|
f, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ct := mime.TypeByExtension(filepath.Ext(filename))
|
||||||
|
basename := path.Base(filename)
|
||||||
|
return e.Attach(f, basename, ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach is used to attach content from an io.Reader to the email.
|
||||||
|
// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type.
|
||||||
|
func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
if _, err = io.Copy(&buffer, r); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
at := &Attachment{
|
||||||
|
Filename: filename,
|
||||||
|
Header: textproto.MIMEHeader{},
|
||||||
|
Content: buffer.Bytes(),
|
||||||
|
}
|
||||||
|
// Get the Content-Type to be used in the MIMEHeader
|
||||||
|
if c != "" {
|
||||||
|
at.Header.Set("Content-Type", c)
|
||||||
|
} else {
|
||||||
|
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
|
||||||
|
at.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
}
|
||||||
|
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
|
||||||
|
at.Header.Set("Content-Transfer-Encoding", "base64")
|
||||||
|
e.Attachments = append(e.Attachments, at)
|
||||||
|
return at, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Email) Send() error {
|
||||||
|
if e.Auth == nil {
|
||||||
|
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
|
||||||
|
}
|
||||||
|
// Merge the To, Cc, and Bcc fields
|
||||||
|
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
|
||||||
|
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
|
||||||
|
// Check to make sure there is at least one recipient and one "From" address
|
||||||
|
if e.From == "" || len(to) == 0 {
|
||||||
|
return errors.New("Must specify at least one From address and one To address")
|
||||||
|
}
|
||||||
|
from, err := mail.ParseAddress(e.From)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
raw, err := e.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
|
||||||
|
func quotePrintEncode(w io.Writer, s string) error {
|
||||||
|
var buf [3]byte
|
||||||
|
mc := 0
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
// We're assuming Unix style text formats as input (LF line break), and
|
||||||
|
// quoted-printble uses CRLF line breaks. (Literal CRs will become
|
||||||
|
// "=0D", but probably shouldn't be there to begin with!)
|
||||||
|
if c == '\n' {
|
||||||
|
io.WriteString(w, "\r\n")
|
||||||
|
mc = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextOut []byte
|
||||||
|
if isPrintable(c) {
|
||||||
|
nextOut = append(buf[:0], c)
|
||||||
|
} else {
|
||||||
|
nextOut = buf[:]
|
||||||
|
qpEscape(nextOut, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a soft line break if the next (encoded) byte would push this line
|
||||||
|
// to or past the limit.
|
||||||
|
if mc+len(nextOut) >= maxLineLength {
|
||||||
|
if _, err := io.WriteString(w, "=\r\n"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mc = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := w.Write(nextOut); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mc += len(nextOut)
|
||||||
|
}
|
||||||
|
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
|
||||||
|
if mc > 0 {
|
||||||
|
io.WriteString(w, "=\r\n")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
|
||||||
|
func isPrintable(c byte) bool {
|
||||||
|
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
|
||||||
|
}
|
||||||
|
|
||||||
|
// qpEscape is a helper function for quotePrintEncode which escapes a
|
||||||
|
// non-printable byte. Expects len(dest) == 3.
|
||||||
|
func qpEscape(dest []byte, c byte) {
|
||||||
|
const nums = "0123456789ABCDEF"
|
||||||
|
dest[0] = '='
|
||||||
|
dest[1] = nums[(c&0xf0)>>4]
|
||||||
|
dest[2] = nums[(c & 0xf)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
|
||||||
|
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
|
||||||
|
for k, v := range t {
|
||||||
|
// Write the header key
|
||||||
|
_, err := fmt.Fprintf(w, "%s:", k)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Write each value in the header
|
||||||
|
for _, c := range v {
|
||||||
|
_, err := fmt.Fprintf(w, " %s\r\n", c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
|
||||||
|
// The output is then written to the specified io.Writer
|
||||||
|
func base64Wrap(w io.Writer, b []byte) {
|
||||||
|
// 57 raw bytes per 76-byte base64 line.
|
||||||
|
const maxRaw = 57
|
||||||
|
// Buffer for each line, including trailing CRLF.
|
||||||
|
var buffer [maxLineLength + len("\r\n")]byte
|
||||||
|
copy(buffer[maxLineLength:], "\r\n")
|
||||||
|
// Process raw chunks until there's no longer enough to fill a line.
|
||||||
|
for len(b) >= maxRaw {
|
||||||
|
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
|
||||||
|
w.Write(buffer[:])
|
||||||
|
b = b[maxRaw:]
|
||||||
|
}
|
||||||
|
// Handle the last chunk of bytes.
|
||||||
|
if len(b) > 0 {
|
||||||
|
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
|
||||||
|
base64.StdEncoding.Encode(out, b)
|
||||||
|
out = append(out, "\r\n"...)
|
||||||
|
w.Write(out)
|
||||||
|
}
|
||||||
|
}
|
27
utils/mail_test.go
Normal file
27
utils/mail_test.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestMail(t *testing.T) {
|
||||||
|
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
|
||||||
|
mail := NewEMail(config)
|
||||||
|
if mail.Username != "astaxie@gmail.com" {
|
||||||
|
t.Fatal("email parse get username error")
|
||||||
|
}
|
||||||
|
if mail.Password != "astaxie" {
|
||||||
|
t.Fatal("email parse get password error")
|
||||||
|
}
|
||||||
|
if mail.Host != "smtp.gmail.com" {
|
||||||
|
t.Fatal("email parse get host error")
|
||||||
|
}
|
||||||
|
if mail.Port != 587 {
|
||||||
|
t.Fatal("email parse get port error")
|
||||||
|
}
|
||||||
|
mail.To = []string{"xiemengjun@gmail.com"}
|
||||||
|
mail.From = "astaxie@gmail.com"
|
||||||
|
mail.Subject = "hi, just from beego!"
|
||||||
|
mail.Text = "Text Body is, of course, supported!"
|
||||||
|
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
|
||||||
|
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
|
||||||
|
mail.Send()
|
||||||
|
}
|
20
utils/rand.go
Normal file
20
utils/rand.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomCreateBytes generate random []byte by specify chars.
|
||||||
|
func RandomCreateBytes(n int, alphabets ...byte) []byte {
|
||||||
|
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
|
var bytes = make([]byte, n)
|
||||||
|
rand.Read(bytes)
|
||||||
|
for i, b := range bytes {
|
||||||
|
if len(alphabets) == 0 {
|
||||||
|
bytes[i] = alphanum[b%byte(len(alphanum))]
|
||||||
|
} else {
|
||||||
|
bytes[i] = alphabets[b%byte(len(alphabets))]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bytes
|
||||||
|
}
|
@ -9,6 +9,7 @@ type BeeMap struct {
|
|||||||
bm map[interface{}]interface{}
|
bm map[interface{}]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewBeeMap return new safemap
|
||||||
func NewBeeMap() *BeeMap {
|
func NewBeeMap() *BeeMap {
|
||||||
return &BeeMap{
|
return &BeeMap{
|
||||||
lock: new(sync.RWMutex),
|
lock: new(sync.RWMutex),
|
||||||
@ -51,12 +52,14 @@ func (m *BeeMap) Check(k interface{}) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete the given key and value.
|
||||||
func (m *BeeMap) Delete(k interface{}) {
|
func (m *BeeMap) Delete(k interface{}) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
defer m.lock.Unlock()
|
defer m.lock.Unlock()
|
||||||
delete(m.bm, k)
|
delete(m.bm, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Items returns all items in safemap.
|
||||||
func (m *BeeMap) Items() map[interface{}]interface{} {
|
func (m *BeeMap) Items() map[interface{}]interface{} {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
defer m.lock.RUnlock()
|
defer m.lock.RUnlock()
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
type reducetype func(interface{}) interface{}
|
type reducetype func(interface{}) interface{}
|
||||||
type filtertype func(interface{}) bool
|
type filtertype func(interface{}) bool
|
||||||
|
|
||||||
|
// InSlice checks given string in string slice or not.
|
||||||
func InSlice(v string, sl []string) bool {
|
func InSlice(v string, sl []string) bool {
|
||||||
for _, vv := range sl {
|
for _, vv := range sl {
|
||||||
if vv == v {
|
if vv == v {
|
||||||
@ -17,6 +18,7 @@ func InSlice(v string, sl []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InSliceIface checks given interface in interface slice.
|
||||||
func InSliceIface(v interface{}, sl []interface{}) bool {
|
func InSliceIface(v interface{}, sl []interface{}) bool {
|
||||||
for _, vv := range sl {
|
for _, vv := range sl {
|
||||||
if vv == v {
|
if vv == v {
|
||||||
@ -26,6 +28,7 @@ func InSliceIface(v interface{}, sl []interface{}) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceRandList generate an int slice from min to max.
|
||||||
func SliceRandList(min, max int) []int {
|
func SliceRandList(min, max int) []int {
|
||||||
if max < min {
|
if max < min {
|
||||||
min, max = max, min
|
min, max = max, min
|
||||||
@ -40,11 +43,13 @@ func SliceRandList(min, max int) []int {
|
|||||||
return list
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceMerge merges interface slices to one slice.
|
||||||
func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) {
|
func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) {
|
||||||
c = append(slice1, slice2...)
|
c = append(slice1, slice2...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceReduce generates a new slice after parsing every value by reduce function
|
||||||
func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
|
func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
|
||||||
for _, v := range slice {
|
for _, v := range slice {
|
||||||
dslice = append(dslice, a(v))
|
dslice = append(dslice, a(v))
|
||||||
@ -52,12 +57,14 @@ func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceRand returns random one from slice.
|
||||||
func SliceRand(a []interface{}) (b interface{}) {
|
func SliceRand(a []interface{}) (b interface{}) {
|
||||||
randnum := rand.Intn(len(a))
|
randnum := rand.Intn(len(a))
|
||||||
b = a[randnum]
|
b = a[randnum]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceSum sums all values in int64 slice.
|
||||||
func SliceSum(intslice []int64) (sum int64) {
|
func SliceSum(intslice []int64) (sum int64) {
|
||||||
for _, v := range intslice {
|
for _, v := range intslice {
|
||||||
sum += v
|
sum += v
|
||||||
@ -65,6 +72,7 @@ func SliceSum(intslice []int64) (sum int64) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceFilter generates a new slice after filter function.
|
||||||
func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
|
func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
|
||||||
for _, v := range slice {
|
for _, v := range slice {
|
||||||
if a(v) {
|
if a(v) {
|
||||||
@ -74,6 +82,7 @@ func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceDiff returns diff slice of slice1 - slice2.
|
||||||
func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
||||||
for _, v := range slice1 {
|
for _, v := range slice1 {
|
||||||
if !InSliceIface(v, slice2) {
|
if !InSliceIface(v, slice2) {
|
||||||
@ -83,6 +92,7 @@ func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceIntersect returns diff slice of slice2 - slice1.
|
||||||
func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
||||||
for _, v := range slice1 {
|
for _, v := range slice1 {
|
||||||
if !InSliceIface(v, slice2) {
|
if !InSliceIface(v, slice2) {
|
||||||
@ -92,6 +102,7 @@ func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceChuck separates one slice to some sized slice.
|
||||||
func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
|
func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
|
||||||
if size >= len(slice) {
|
if size >= len(slice) {
|
||||||
chunkslice = append(chunkslice, slice)
|
chunkslice = append(chunkslice, slice)
|
||||||
@ -105,6 +116,7 @@ func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceRange generates a new slice from begin to end with step duration of int64 number.
|
||||||
func SliceRange(start, end, step int64) (intslice []int64) {
|
func SliceRange(start, end, step int64) (intslice []int64) {
|
||||||
for i := start; i <= end; i += step {
|
for i := start; i <= end; i += step {
|
||||||
intslice = append(intslice, i)
|
intslice = append(intslice, i)
|
||||||
@ -112,6 +124,7 @@ func SliceRange(start, end, step int64) (intslice []int64) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlicePad prepends size number of val into slice.
|
||||||
func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
|
func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
|
||||||
if size <= len(slice) {
|
if size <= len(slice) {
|
||||||
return slice
|
return slice
|
||||||
@ -122,6 +135,7 @@ func SlicePad(slice []interface{}, size int, val interface{}) []interface{} {
|
|||||||
return slice
|
return slice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceUnique cleans repeated values in slice.
|
||||||
func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
|
func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
|
||||||
for _, v := range slice {
|
for _, v := range slice {
|
||||||
if !InSliceIface(v, uniqueslice) {
|
if !InSliceIface(v, uniqueslice) {
|
||||||
@ -131,6 +145,7 @@ func SliceUnique(slice []interface{}) (uniqueslice []interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SliceShuffle shuffles a slice.
|
||||||
func SliceShuffle(slice []interface{}) []interface{} {
|
func SliceShuffle(slice []interface{}) []interface{} {
|
||||||
for i := 0; i < len(slice); i++ {
|
for i := 0; i < len(slice); i++ {
|
||||||
a := rand.Intn(len(slice))
|
a := rand.Intn(len(slice))
|
||||||
|
@ -41,13 +41,16 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Valid function type
|
||||||
type ValidFunc struct {
|
type ValidFunc struct {
|
||||||
Name string
|
Name string
|
||||||
Params []interface{}
|
Params []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate function map
|
||||||
type Funcs map[string]reflect.Value
|
type Funcs map[string]reflect.Value
|
||||||
|
|
||||||
|
// validate values with named type string
|
||||||
func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) {
|
func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
@ -32,6 +32,7 @@ type ValidationResult struct {
|
|||||||
Ok bool
|
Ok bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get ValidationResult by given key string.
|
||||||
func (r *ValidationResult) Key(key string) *ValidationResult {
|
func (r *ValidationResult) Key(key string) *ValidationResult {
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
r.Error.Key = key
|
r.Error.Key = key
|
||||||
@ -39,6 +40,7 @@ func (r *ValidationResult) Key(key string) *ValidationResult {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set ValidationResult message by string or format string with args
|
||||||
func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult {
|
func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult {
|
||||||
if r.Error != nil {
|
if r.Error != nil {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
@ -56,10 +58,12 @@ type Validation struct {
|
|||||||
ErrorsMap map[string]*ValidationError
|
ErrorsMap map[string]*ValidationError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean all ValidationError.
|
||||||
func (v *Validation) Clear() {
|
func (v *Validation) Clear() {
|
||||||
v.Errors = []*ValidationError{}
|
v.Errors = []*ValidationError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Has ValidationError nor not.
|
||||||
func (v *Validation) HasErrors() bool {
|
func (v *Validation) HasErrors() bool {
|
||||||
return len(v.Errors) > 0
|
return len(v.Errors) > 0
|
||||||
}
|
}
|
||||||
@ -101,67 +105,83 @@ func (v *Validation) Range(obj interface{}, min, max int, key string) *Validatio
|
|||||||
return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj)
|
return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is longer than min size if type is string or slice
|
||||||
func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult {
|
func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult {
|
||||||
return v.apply(MinSize{min, key}, obj)
|
return v.apply(MinSize{min, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is shorter than max size if type is string or slice
|
||||||
func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult {
|
func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult {
|
||||||
return v.apply(MaxSize{max, key}, obj)
|
return v.apply(MaxSize{max, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is same length to n if type is string or slice
|
||||||
func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult {
|
func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult {
|
||||||
return v.apply(Length{n, key}, obj)
|
return v.apply(Length{n, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is [a-zA-Z] if type is string
|
||||||
func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Alpha{key}, obj)
|
return v.apply(Alpha{key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is [0-9] if type is string
|
||||||
func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Numeric{key}, obj)
|
return v.apply(Numeric{key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is [0-9a-zA-Z] if type is string
|
||||||
func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(AlphaNumeric{key}, obj)
|
return v.apply(AlphaNumeric{key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj matches regexp if type is string
|
||||||
func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
|
func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
|
||||||
return v.apply(Match{regex, key}, obj)
|
return v.apply(Match{regex, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj doesn't match regexp if type is string
|
||||||
func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
|
func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
|
||||||
return v.apply(NoMatch{Match{Regexp: regex}, key}, obj)
|
return v.apply(NoMatch{Match{Regexp: regex}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is [0-9a-zA-Z_-] if type is string
|
||||||
func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj)
|
return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is email address if type is string
|
||||||
func (v *Validation) Email(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Email(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Email{Match{Regexp: emailPattern}, key}, obj)
|
return v.apply(Email{Match{Regexp: emailPattern}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is IP address if type is string
|
||||||
func (v *Validation) IP(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) IP(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(IP{Match{Regexp: ipPattern}, key}, obj)
|
return v.apply(IP{Match{Regexp: ipPattern}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is base64 encoded if type is string
|
||||||
func (v *Validation) Base64(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Base64(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj)
|
return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is chinese mobile number if type is string
|
||||||
func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj)
|
return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is chinese telephone number if type is string
|
||||||
func (v *Validation) Tel(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Tel(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Tel{Match{Regexp: telPattern}, key}, obj)
|
return v.apply(Tel{Match{Regexp: telPattern}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is chinese mobile or telephone number if type is string
|
||||||
func (v *Validation) Phone(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) Phone(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}},
|
return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}},
|
||||||
Tel{Match: Match{Regexp: telPattern}}, key}, obj)
|
Tel{Match: Match{Regexp: telPattern}}, key}, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that the obj is chinese zip code if type is string
|
||||||
func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult {
|
func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult {
|
||||||
return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj)
|
return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj)
|
||||||
}
|
}
|
||||||
@ -210,6 +230,7 @@ func (v *Validation) setError(err *ValidationError) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set error message for one field in ValidationError
|
||||||
func (v *Validation) SetError(fieldName string, errMsg string) *ValidationError {
|
func (v *Validation) SetError(fieldName string, errMsg string) *ValidationError {
|
||||||
err := &ValidationError{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg}
|
err := &ValidationError{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg}
|
||||||
v.setError(err)
|
v.setError(err)
|
||||||
@ -230,6 +251,7 @@ func (v *Validation) Check(obj interface{}, checks ...Validator) *ValidationResu
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate a struct.
|
||||||
// the obj parameter must be a struct or a struct pointer
|
// the obj parameter must be a struct or a struct pointer
|
||||||
func (v *Validation) Valid(obj interface{}) (b bool, err error) {
|
func (v *Validation) Valid(obj interface{}) (b bool, err error) {
|
||||||
objT := reflect.TypeOf(obj)
|
objT := reflect.TypeOf(obj)
|
||||||
|
Loading…
Reference in New Issue
Block a user