mirror of
https://github.com/astaxie/beego.git
synced 2024-11-10 18:20:55 +00:00
version 1.1.2 release
This commit is contained in:
commit
6497f29ed7
2
.gitignore
vendored
2
.gitignore
vendored
@ -1 +1,3 @@
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
24
beego.go
24
beego.go
@ -2,6 +2,7 @@ package beego
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// beego web framework version.
|
// beego web framework version.
|
||||||
const VERSION = "1.1.1"
|
const VERSION = "1.1.2"
|
||||||
|
|
||||||
type hookfunc func() error //hook function to run
|
type hookfunc func() error //hook function to run
|
||||||
var hooks []hookfunc //hook function slice to store the hookfunc
|
var hooks []hookfunc //hook function slice to store the hookfunc
|
||||||
@ -174,6 +175,16 @@ func AddAPPStartHook(hf hookfunc) {
|
|||||||
// Run beego application.
|
// Run beego application.
|
||||||
// it's alias of App.Run.
|
// it's alias of App.Run.
|
||||||
func Run() {
|
func Run() {
|
||||||
|
initBeforeHttpRun()
|
||||||
|
|
||||||
|
if EnableAdmin {
|
||||||
|
go BeeAdminApp.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
BeeApp.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initBeforeHttpRun() {
|
||||||
// if AppConfigPath not In the conf/app.conf reParse config
|
// if AppConfigPath not In the conf/app.conf reParse config
|
||||||
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
|
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
|
||||||
err := ParseConfig()
|
err := ParseConfig()
|
||||||
@ -222,12 +233,13 @@ func Run() {
|
|||||||
middleware.VERSION = VERSION
|
middleware.VERSION = VERSION
|
||||||
middleware.AppName = AppName
|
middleware.AppName = AppName
|
||||||
middleware.RegisterErrorHandler()
|
middleware.RegisterErrorHandler()
|
||||||
|
}
|
||||||
|
|
||||||
if EnableAdmin {
|
func TestBeegoInit(apppath string) {
|
||||||
go BeeAdminApp.Run()
|
AppPath = apppath
|
||||||
}
|
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
||||||
|
os.Chdir(AppPath)
|
||||||
BeeApp.Run()
|
initBeforeHttpRun()
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
30
config.go
30
config.go
@ -11,12 +11,14 @@ import (
|
|||||||
"github.com/astaxie/beego/config"
|
"github.com/astaxie/beego/config"
|
||||||
"github.com/astaxie/beego/logs"
|
"github.com/astaxie/beego/logs"
|
||||||
"github.com/astaxie/beego/session"
|
"github.com/astaxie/beego/session"
|
||||||
|
"github.com/astaxie/beego/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
BeeApp *App // beego application
|
BeeApp *App // beego application
|
||||||
AppName string
|
AppName string
|
||||||
AppPath string
|
AppPath string
|
||||||
|
workPath string
|
||||||
AppConfigPath string
|
AppConfigPath string
|
||||||
StaticDir map[string]string
|
StaticDir map[string]string
|
||||||
TemplateCache map[string]*template.Template // template caching map
|
TemplateCache map[string]*template.Template // template caching map
|
||||||
@ -58,15 +60,28 @@ var (
|
|||||||
EnableAdmin bool // flag of enable admin module to log every request info.
|
EnableAdmin bool // flag of enable admin module to log every request info.
|
||||||
AdminHttpAddr string // http server configurations for admin module.
|
AdminHttpAddr string // http server configurations for admin module.
|
||||||
AdminHttpPort int
|
AdminHttpPort int
|
||||||
|
FlashName string // name of the flash variable found in response header and cookie
|
||||||
|
FlashSeperator string // used to seperate flash key:value
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// create beego application
|
// create beego application
|
||||||
BeeApp = NewApp()
|
BeeApp = NewApp()
|
||||||
|
|
||||||
|
workPath, _ = os.Getwd()
|
||||||
|
workPath, _ = filepath.Abs(workPath)
|
||||||
// initialize default configurations
|
// initialize default configurations
|
||||||
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
|
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
|
||||||
os.Chdir(AppPath)
|
|
||||||
|
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
||||||
|
|
||||||
|
if workPath != AppPath {
|
||||||
|
if utils.FileExists(AppConfigPath) {
|
||||||
|
os.Chdir(AppPath)
|
||||||
|
} else {
|
||||||
|
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
StaticDir = make(map[string]string)
|
StaticDir = make(map[string]string)
|
||||||
StaticDir["/static"] = "static"
|
StaticDir["/static"] = "static"
|
||||||
@ -105,8 +120,6 @@ func init() {
|
|||||||
|
|
||||||
EnableGzip = false
|
EnableGzip = false
|
||||||
|
|
||||||
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
|
||||||
|
|
||||||
HttpServerTimeOut = 0
|
HttpServerTimeOut = 0
|
||||||
|
|
||||||
ErrorsShow = true
|
ErrorsShow = true
|
||||||
@ -123,6 +136,9 @@ func init() {
|
|||||||
AdminHttpAddr = "127.0.0.1"
|
AdminHttpAddr = "127.0.0.1"
|
||||||
AdminHttpPort = 8088
|
AdminHttpPort = 8088
|
||||||
|
|
||||||
|
FlashName = "BEEGO_FLASH"
|
||||||
|
FlashSeperator = "BEEGOFLASH"
|
||||||
|
|
||||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||||
|
|
||||||
// init BeeLogger
|
// init BeeLogger
|
||||||
@ -271,6 +287,14 @@ func ParseConfig() (err error) {
|
|||||||
BeegoServerName = serverName
|
BeegoServerName = serverName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if flashname := AppConfig.String("FlashName"); flashname != "" {
|
||||||
|
FlashName = flashname
|
||||||
|
}
|
||||||
|
|
||||||
|
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
|
||||||
|
FlashSeperator = flashseperator
|
||||||
|
}
|
||||||
|
|
||||||
if sd := AppConfig.String("StaticDir"); sd != "" {
|
if sd := AppConfig.String("StaticDir"); sd != "" {
|
||||||
for k := range StaticDir {
|
for k := range StaticDir {
|
||||||
delete(StaticDir, k)
|
delete(StaticDir, k)
|
||||||
|
15
config_test.go
Normal file
15
config_test.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package beego
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaults(t *testing.T) {
|
||||||
|
if FlashName != "BEEGO_FLASH" {
|
||||||
|
t.Errorf("FlashName was not set to default.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if FlashSeperator != "BEEGOFLASH" {
|
||||||
|
t.Errorf("FlashName was not set to default.")
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -13,11 +14,13 @@ import (
|
|||||||
// BeegoInput operates the http request header ,data ,cookie and body.
|
// BeegoInput operates the http request header ,data ,cookie and body.
|
||||||
// it also contains router params and current session.
|
// it also contains router params and current session.
|
||||||
type BeegoInput struct {
|
type BeegoInput struct {
|
||||||
CruSession session.SessionStore
|
CruSession session.SessionStore
|
||||||
Params map[string]string
|
Params map[string]string
|
||||||
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
|
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
|
||||||
Request *http.Request
|
Request *http.Request
|
||||||
RequestBody []byte
|
RequestBody []byte
|
||||||
|
RunController reflect.Type
|
||||||
|
RunMethod string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInput return BeegoInput generated by http.Request.
|
// NewInput return BeegoInput generated by http.Request.
|
||||||
|
@ -62,7 +62,6 @@ type ControllerInterface interface {
|
|||||||
|
|
||||||
// Init generates default values of controller operations.
|
// Init generates default values of controller operations.
|
||||||
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
|
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
|
||||||
c.Data = make(map[interface{}]interface{})
|
|
||||||
c.Layout = ""
|
c.Layout = ""
|
||||||
c.TplNames = ""
|
c.TplNames = ""
|
||||||
c.controllerName = controllerName
|
c.controllerName = controllerName
|
||||||
|
17
flash.go
17
flash.go
@ -6,9 +6,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// the separation string when encoding flash data.
|
|
||||||
const BEEGO_FLASH_SEP = "#BEEGOFLASH#"
|
|
||||||
|
|
||||||
// FlashData is a tools to maintain data when using across request.
|
// FlashData is a tools to maintain data when using across request.
|
||||||
type FlashData struct {
|
type FlashData struct {
|
||||||
Data map[string]string
|
Data map[string]string
|
||||||
@ -54,29 +51,27 @@ func (fd *FlashData) Store(c *Controller) {
|
|||||||
c.Data["flash"] = fd.Data
|
c.Data["flash"] = fd.Data
|
||||||
var flashValue string
|
var flashValue string
|
||||||
for key, value := range fd.Data {
|
for key, value := range fd.Data {
|
||||||
flashValue += "\x00" + key + BEEGO_FLASH_SEP + value + "\x00"
|
flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
|
||||||
}
|
}
|
||||||
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/")
|
c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadFromRequest parsed flash data from encoded values in cookie.
|
// ReadFromRequest parsed flash data from encoded values in cookie.
|
||||||
func ReadFromRequest(c *Controller) *FlashData {
|
func ReadFromRequest(c *Controller) *FlashData {
|
||||||
flash := &FlashData{
|
flash := NewFlash()
|
||||||
Data: make(map[string]string),
|
if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
|
||||||
}
|
|
||||||
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
|
|
||||||
v, _ := url.QueryUnescape(cookie.Value)
|
v, _ := url.QueryUnescape(cookie.Value)
|
||||||
vals := strings.Split(v, "\x00")
|
vals := strings.Split(v, "\x00")
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
if len(v) > 0 {
|
if len(v) > 0 {
|
||||||
kv := strings.Split(v, BEEGO_FLASH_SEP)
|
kv := strings.Split(v, FlashSeperator)
|
||||||
if len(kv) == 2 {
|
if len(kv) == 2 {
|
||||||
flash.Data[kv[0]] = kv[1]
|
flash.Data[kv[0]] = kv[1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//read one time then delete it
|
//read one time then delete it
|
||||||
c.Ctx.SetCookie("BEEGO_FLASH", "", -1, "/")
|
c.Ctx.SetCookie(FlashName, "", -1, "/")
|
||||||
}
|
}
|
||||||
c.Data["flash"] = flash.Data
|
c.Data["flash"] = flash.Data
|
||||||
return flash
|
return flash
|
||||||
|
40
flash_test.go
Normal file
40
flash_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package beego
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestFlashController struct {
|
||||||
|
Controller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *TestFlashController) TestWriteFlash() {
|
||||||
|
flash := NewFlash()
|
||||||
|
flash.Notice("TestFlashString")
|
||||||
|
flash.Store(&this.Controller)
|
||||||
|
// we choose to serve json because we don't want to load a template html file
|
||||||
|
this.ServeJson(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlashHeader(t *testing.T) {
|
||||||
|
// create fake GET request
|
||||||
|
r, _ := http.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// setup the handler
|
||||||
|
handler := NewControllerRegistor()
|
||||||
|
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
// get the Set-Cookie value
|
||||||
|
sc := w.Header().Get("Set-Cookie")
|
||||||
|
// match for the expected header
|
||||||
|
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
|
||||||
|
// validate the assertion
|
||||||
|
if res != true {
|
||||||
|
t.Errorf("TestFlashHeader() unable to validate flash message")
|
||||||
|
}
|
||||||
|
}
|
13
log.go
13
log.go
@ -22,12 +22,21 @@ func SetLevel(l int) {
|
|||||||
BeeLogger.SetLevel(l)
|
BeeLogger.SetLevel(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetLogFuncCall(b bool) {
|
||||||
|
BeeLogger.EnableFuncCallDepth(b)
|
||||||
|
BeeLogger.SetLogFuncCallDepth(3)
|
||||||
|
}
|
||||||
|
|
||||||
// logger references the used application logger.
|
// logger references the used application logger.
|
||||||
var BeeLogger *logs.BeeLogger
|
var BeeLogger *logs.BeeLogger
|
||||||
|
|
||||||
// SetLogger sets a new logger.
|
// SetLogger sets a new logger.
|
||||||
func SetLogger(adaptername string, config string) {
|
func SetLogger(adaptername string, config string) error {
|
||||||
BeeLogger.SetLogger(adaptername, config)
|
err := BeeLogger.SetLogger(adaptername, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trace logs a message at trace level.
|
// Trace logs a message at trace level.
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
func TestConsole(t *testing.T) {
|
func TestConsole(t *testing.T) {
|
||||||
log := NewLogger(10000)
|
log := NewLogger(10000)
|
||||||
|
log.EnableFuncCallDepth(true)
|
||||||
log.SetLogger("console", "")
|
log.SetLogger("console", "")
|
||||||
log.Trace("trace")
|
log.Trace("trace")
|
||||||
log.Info("info")
|
log.Info("info")
|
||||||
@ -23,6 +24,7 @@ func TestConsole(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkConsole(b *testing.B) {
|
func BenchmarkConsole(b *testing.B) {
|
||||||
log := NewLogger(10000)
|
log := NewLogger(10000)
|
||||||
|
log.EnableFuncCallDepth(true)
|
||||||
log.SetLogger("console", "")
|
log.SetLogger("console", "")
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
log.Trace("trace")
|
log.Trace("trace")
|
||||||
|
@ -97,12 +97,12 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
|
|||||||
if len(w.Filename) == 0 {
|
if len(w.Filename) == 0 {
|
||||||
return errors.New("jsonconfig must have filename")
|
return errors.New("jsonconfig must have filename")
|
||||||
}
|
}
|
||||||
err = w.StartLogger()
|
err = w.startLogger()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start file logger. create log file and set to locker-inside file writer.
|
// start file logger. create log file and set to locker-inside file writer.
|
||||||
func (w *FileLogWriter) StartLogger() error {
|
func (w *FileLogWriter) startLogger() error {
|
||||||
fd, err := w.createLogFile()
|
fd, err := w.createLogFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -199,7 +199,7 @@ func (w *FileLogWriter) DoRotate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// re-start logger
|
// re-start logger
|
||||||
err = w.StartLogger()
|
err = w.startLogger()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Rotate StartLogger: %s\n", err)
|
return fmt.Errorf("Rotate StartLogger: %s\n", err)
|
||||||
}
|
}
|
||||||
|
44
logs/log.go
44
logs/log.go
@ -2,6 +2,8 @@ package logs
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,10 +45,12 @@ func Register(name string, log loggerType) {
|
|||||||
// BeeLogger is default logger in beego application.
|
// BeeLogger is default logger in beego application.
|
||||||
// it can contain several providers and log message into all providers.
|
// it can contain several providers and log message into all providers.
|
||||||
type BeeLogger struct {
|
type BeeLogger struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
level int
|
level int
|
||||||
msg chan *logMsg
|
enableFuncCallDepth bool
|
||||||
outputs map[string]LoggerInterface
|
loggerFuncCallDepth int
|
||||||
|
msg chan *logMsg
|
||||||
|
outputs map[string]LoggerInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
type logMsg struct {
|
type logMsg struct {
|
||||||
@ -59,10 +63,11 @@ type logMsg struct {
|
|||||||
// if the buffering chan is full, logger adapters write to file or other way.
|
// if the buffering chan is full, logger adapters write to file or other way.
|
||||||
func NewLogger(channellen int64) *BeeLogger {
|
func NewLogger(channellen int64) *BeeLogger {
|
||||||
bl := new(BeeLogger)
|
bl := new(BeeLogger)
|
||||||
|
bl.loggerFuncCallDepth = 2
|
||||||
bl.msg = make(chan *logMsg, channellen)
|
bl.msg = make(chan *logMsg, channellen)
|
||||||
bl.outputs = make(map[string]LoggerInterface)
|
bl.outputs = make(map[string]LoggerInterface)
|
||||||
//bl.SetLogger("console", "") // default output to console
|
//bl.SetLogger("console", "") // default output to console
|
||||||
go bl.StartLogger()
|
go bl.startLogger()
|
||||||
return bl
|
return bl
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,7 +78,10 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
|||||||
defer bl.lock.Unlock()
|
defer bl.lock.Unlock()
|
||||||
if log, ok := adapters[adaptername]; ok {
|
if log, ok := adapters[adaptername]; ok {
|
||||||
lg := log()
|
lg := log()
|
||||||
lg.Init(config)
|
err := lg.Init(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
bl.outputs[adaptername] = lg
|
bl.outputs[adaptername] = lg
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
@ -100,7 +108,17 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
|
|||||||
}
|
}
|
||||||
lm := new(logMsg)
|
lm := new(logMsg)
|
||||||
lm.level = loglevel
|
lm.level = loglevel
|
||||||
lm.msg = msg
|
if bl.enableFuncCallDepth {
|
||||||
|
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
|
||||||
|
if ok {
|
||||||
|
_, filename := path.Split(file)
|
||||||
|
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
|
||||||
|
} else {
|
||||||
|
lm.msg = msg
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
lm.msg = msg
|
||||||
|
}
|
||||||
bl.msg <- lm
|
bl.msg <- lm
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -111,9 +129,19 @@ func (bl *BeeLogger) SetLevel(l int) {
|
|||||||
bl.level = l
|
bl.level = l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set log funcCallDepth
|
||||||
|
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
|
||||||
|
bl.loggerFuncCallDepth = d
|
||||||
|
}
|
||||||
|
|
||||||
|
// enable log funcCallDepth
|
||||||
|
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
|
||||||
|
bl.enableFuncCallDepth = b
|
||||||
|
}
|
||||||
|
|
||||||
// start logger chan reading.
|
// start logger chan reading.
|
||||||
// when chan is full, write logs.
|
// when chan is full, write logs.
|
||||||
func (bl *BeeLogger) StartLogger() {
|
func (bl *BeeLogger) startLogger() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case bm := <-bl.msg:
|
case bm := <-bl.msg:
|
||||||
|
112
orm/db_utils.go
112
orm/db_utils.go
@ -51,9 +51,16 @@ outFor:
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := arg.(type) {
|
kind := val.Kind()
|
||||||
case []byte:
|
if kind == reflect.Ptr {
|
||||||
case string:
|
val = val.Elem()
|
||||||
|
kind = val.Kind()
|
||||||
|
arg = val.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.String:
|
||||||
|
v := val.String()
|
||||||
if fi != nil {
|
if fi != nil {
|
||||||
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
|
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
|
||||||
var t time.Time
|
var t time.Time
|
||||||
@ -78,61 +85,66 @@ outFor:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
arg = v
|
arg = v
|
||||||
case time.Time:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
if fi != nil && fi.fieldType == TypeDateField {
|
arg = val.Int()
|
||||||
arg = v.In(tz).Format(format_Date)
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
} else {
|
arg = val.Uint()
|
||||||
arg = v.In(tz).Format(format_DateTime)
|
case reflect.Float32:
|
||||||
}
|
arg, _ = StrTo(ToStr(arg)).Float64()
|
||||||
default:
|
case reflect.Float64:
|
||||||
kind := val.Kind()
|
arg = val.Float()
|
||||||
switch kind {
|
case reflect.Bool:
|
||||||
case reflect.Slice, reflect.Array:
|
arg = val.Bool()
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
var args []interface{}
|
if _, ok := arg.([]byte); ok {
|
||||||
for i := 0; i < val.Len(); i++ {
|
|
||||||
v := val.Index(i)
|
|
||||||
|
|
||||||
var vu interface{}
|
|
||||||
if v.CanInterface() {
|
|
||||||
vu = v.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
if vu == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(args, vu)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) > 0 {
|
|
||||||
p := getFlatParams(fi, args, tz)
|
|
||||||
params = append(params, p...)
|
|
||||||
}
|
|
||||||
continue outFor
|
continue outFor
|
||||||
|
}
|
||||||
|
|
||||||
case reflect.Ptr, reflect.Struct:
|
var args []interface{}
|
||||||
ind := reflect.Indirect(val)
|
for i := 0; i < val.Len(); i++ {
|
||||||
|
v := val.Index(i)
|
||||||
|
|
||||||
if ind.Kind() == reflect.Struct {
|
var vu interface{}
|
||||||
typ := ind.Type()
|
if v.CanInterface() {
|
||||||
name := getFullName(typ)
|
vu = v.Interface()
|
||||||
var value interface{}
|
}
|
||||||
if mmi, ok := modelCache.getByFN(name); ok {
|
|
||||||
if _, vu, exist := getExistPk(mmi, ind); exist {
|
|
||||||
value = vu
|
|
||||||
}
|
|
||||||
}
|
|
||||||
arg = value
|
|
||||||
|
|
||||||
if arg == nil {
|
if vu == nil {
|
||||||
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
args = append(args, vu)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
p := getFlatParams(fi, args, tz)
|
||||||
|
params = append(params, p...)
|
||||||
|
}
|
||||||
|
continue outFor
|
||||||
|
case reflect.Struct:
|
||||||
|
if v, ok := arg.(time.Time); ok {
|
||||||
|
if fi != nil && fi.fieldType == TypeDateField {
|
||||||
|
arg = v.In(tz).Format(format_Date)
|
||||||
} else {
|
} else {
|
||||||
arg = ind.Interface()
|
arg = v.In(tz).Format(format_DateTime)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
typ := val.Type()
|
||||||
|
name := getFullName(typ)
|
||||||
|
var value interface{}
|
||||||
|
if mmi, ok := modelCache.getByFN(name); ok {
|
||||||
|
if _, vu, exist := getExistPk(mmi, val); exist {
|
||||||
|
value = vu
|
||||||
|
}
|
||||||
|
}
|
||||||
|
arg = value
|
||||||
|
|
||||||
|
if arg == nil {
|
||||||
|
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
params = append(params, arg)
|
params = append(params, arg)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -144,6 +144,45 @@ type DataNull struct {
|
|||||||
NullInt64 sql.NullInt64 `orm:"null"`
|
NullInt64 sql.NullInt64 `orm:"null"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type String string
|
||||||
|
type Boolean bool
|
||||||
|
type Byte byte
|
||||||
|
type Rune rune
|
||||||
|
type Int int
|
||||||
|
type Int8 int8
|
||||||
|
type Int16 int16
|
||||||
|
type Int32 int32
|
||||||
|
type Int64 int64
|
||||||
|
type Uint uint
|
||||||
|
type Uint8 uint8
|
||||||
|
type Uint16 uint16
|
||||||
|
type Uint32 uint32
|
||||||
|
type Uint64 uint64
|
||||||
|
type Float32 float64
|
||||||
|
type Float64 float64
|
||||||
|
|
||||||
|
type DataCustom struct {
|
||||||
|
Id int
|
||||||
|
Boolean Boolean
|
||||||
|
Char string `orm:"size(50)"`
|
||||||
|
Text string `orm:"type(text)"`
|
||||||
|
Byte Byte
|
||||||
|
Rune Rune
|
||||||
|
Int Int
|
||||||
|
Int8 Int8
|
||||||
|
Int16 Int16
|
||||||
|
Int32 Int32
|
||||||
|
Int64 Int64
|
||||||
|
Uint Uint
|
||||||
|
Uint8 Uint8
|
||||||
|
Uint16 Uint16
|
||||||
|
Uint32 Uint32
|
||||||
|
Uint64 Uint64
|
||||||
|
Float32 Float32
|
||||||
|
Float64 Float64
|
||||||
|
Decimal Float64 `orm:"digits(8);decimals(4)"`
|
||||||
|
}
|
||||||
|
|
||||||
// only for mysql
|
// only for mysql
|
||||||
type UserBig struct {
|
type UserBig struct {
|
||||||
Id uint64
|
Id uint64
|
||||||
@ -155,7 +194,7 @@ type User struct {
|
|||||||
UserName string `orm:"size(30);unique"`
|
UserName string `orm:"size(30);unique"`
|
||||||
Email string `orm:"size(100)"`
|
Email string `orm:"size(100)"`
|
||||||
Password string `orm:"size(100)"`
|
Password string `orm:"size(100)"`
|
||||||
Status int16
|
Status int16 `orm:"column(Status)"`
|
||||||
IsStaff bool
|
IsStaff bool
|
||||||
IsActive bool `orm:"default(1)"`
|
IsActive bool `orm:"default(1)"`
|
||||||
Created time.Time `orm:"auto_now_add;type(date)"`
|
Created time.Time `orm:"auto_now_add;type(date)"`
|
||||||
|
@ -80,7 +80,6 @@ func getTableUnique(val reflect.Value) [][]string {
|
|||||||
|
|
||||||
// get snaked column name
|
// get snaked column name
|
||||||
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
||||||
col = strings.ToLower(col)
|
|
||||||
column := col
|
column := col
|
||||||
if col == "" {
|
if col == "" {
|
||||||
column = snakeString(sf.Name)
|
column = snakeString(sf.Name)
|
||||||
@ -99,34 +98,41 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
|
|||||||
// return field type as type constant from reflect.Value
|
// return field type as type constant from reflect.Value
|
||||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||||
elm := reflect.Indirect(val)
|
elm := reflect.Indirect(val)
|
||||||
switch elm.Interface().(type) {
|
switch elm.Kind() {
|
||||||
case int8:
|
case reflect.Int8:
|
||||||
ft = TypeBitField
|
ft = TypeBitField
|
||||||
case int16:
|
case reflect.Int16:
|
||||||
ft = TypeSmallIntegerField
|
ft = TypeSmallIntegerField
|
||||||
case int32, int:
|
case reflect.Int32, reflect.Int:
|
||||||
ft = TypeIntegerField
|
ft = TypeIntegerField
|
||||||
case int64, sql.NullInt64:
|
case reflect.Int64:
|
||||||
ft = TypeBigIntegerField
|
ft = TypeBigIntegerField
|
||||||
case uint8:
|
case reflect.Uint8:
|
||||||
ft = TypePositiveBitField
|
ft = TypePositiveBitField
|
||||||
case uint16:
|
case reflect.Uint16:
|
||||||
ft = TypePositiveSmallIntegerField
|
ft = TypePositiveSmallIntegerField
|
||||||
case uint32, uint:
|
case reflect.Uint32, reflect.Uint:
|
||||||
ft = TypePositiveIntegerField
|
ft = TypePositiveIntegerField
|
||||||
case uint64:
|
case reflect.Uint64:
|
||||||
ft = TypePositiveBigIntegerField
|
ft = TypePositiveBigIntegerField
|
||||||
case float32, float64, sql.NullFloat64:
|
case reflect.Float32, reflect.Float64:
|
||||||
ft = TypeFloatField
|
ft = TypeFloatField
|
||||||
case bool, sql.NullBool:
|
case reflect.Bool:
|
||||||
ft = TypeBooleanField
|
ft = TypeBooleanField
|
||||||
case string, sql.NullString:
|
case reflect.String:
|
||||||
ft = TypeCharField
|
ft = TypeCharField
|
||||||
default:
|
default:
|
||||||
if elm.CanInterface() {
|
switch elm.Interface().(type) {
|
||||||
if _, ok := elm.Interface().(time.Time); ok {
|
case sql.NullInt64:
|
||||||
ft = TypeDateTimeField
|
ft = TypeBigIntegerField
|
||||||
}
|
case sql.NullFloat64:
|
||||||
|
ft = TypeFloatField
|
||||||
|
case sql.NullBool:
|
||||||
|
ft = TypeBooleanField
|
||||||
|
case sql.NullString:
|
||||||
|
ft = TypeCharField
|
||||||
|
case time.Time:
|
||||||
|
ft = TypeDateTimeField
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ft&IsFieldType == 0 {
|
if ft&IsFieldType == 0 {
|
||||||
|
@ -149,7 +149,7 @@ func TestGetDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSyncDb(t *testing.T) {
|
func TestSyncDb(t *testing.T) {
|
||||||
RegisterModel(new(Data), new(DataNull))
|
RegisterModel(new(Data), new(DataNull), new(DataCustom))
|
||||||
RegisterModel(new(User))
|
RegisterModel(new(User))
|
||||||
RegisterModel(new(Profile))
|
RegisterModel(new(Profile))
|
||||||
RegisterModel(new(Post))
|
RegisterModel(new(Post))
|
||||||
@ -165,7 +165,7 @@ func TestSyncDb(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterModels(t *testing.T) {
|
func TestRegisterModels(t *testing.T) {
|
||||||
RegisterModel(new(Data), new(DataNull))
|
RegisterModel(new(Data), new(DataNull), new(DataCustom))
|
||||||
RegisterModel(new(User))
|
RegisterModel(new(User))
|
||||||
RegisterModel(new(Profile))
|
RegisterModel(new(Profile))
|
||||||
RegisterModel(new(Post))
|
RegisterModel(new(Post))
|
||||||
@ -309,6 +309,39 @@ func TestNullDataTypes(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
|
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDataCustomTypes(t *testing.T) {
|
||||||
|
d := DataCustom{}
|
||||||
|
ind := reflect.Indirect(reflect.ValueOf(&d))
|
||||||
|
|
||||||
|
for name, value := range Data_Values {
|
||||||
|
e := ind.FieldByName(name)
|
||||||
|
if !e.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e.Set(reflect.ValueOf(value).Convert(e.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := dORM.Insert(&d)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(id, 1))
|
||||||
|
|
||||||
|
d = DataCustom{Id: 1}
|
||||||
|
err = dORM.Read(&d)
|
||||||
|
throwFail(t, err)
|
||||||
|
|
||||||
|
ind = reflect.Indirect(reflect.ValueOf(&d))
|
||||||
|
|
||||||
|
for name, value := range Data_Values {
|
||||||
|
e := ind.FieldByName(name)
|
||||||
|
if !e.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vu := e.Interface()
|
||||||
|
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
|
||||||
|
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCRUD(t *testing.T) {
|
func TestCRUD(t *testing.T) {
|
||||||
profile := NewProfile()
|
profile := NewProfile()
|
||||||
profile.Age = 30
|
profile.Age = 30
|
||||||
@ -562,6 +595,10 @@ func TestOperators(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
|
num, err = qs.Filter("user_name__exact", String("slene")).Count()
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
|
||||||
num, err = qs.Filter("user_name__exact", "slene").Count()
|
num, err = qs.Filter("user_name__exact", "slene").Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
@ -602,11 +639,11 @@ func TestOperators(t *testing.T) {
|
|||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 3))
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
|
||||||
num, err = qs.Filter("status__lt", 3).Count()
|
num, err = qs.Filter("status__lt", Uint(3)).Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 2))
|
throwFail(t, AssertIs(num, 2))
|
||||||
|
|
||||||
num, err = qs.Filter("status__lte", 3).Count()
|
num, err = qs.Filter("status__lte", Int(3)).Count()
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 3))
|
throwFail(t, AssertIs(num, 3))
|
||||||
|
|
||||||
@ -1380,7 +1417,7 @@ func TestRawQueryRow(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
cols = []string{
|
cols = []string{
|
||||||
"id", "status", "profile_id",
|
"id", "Status", "profile_id",
|
||||||
}
|
}
|
||||||
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
|
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||||
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
|
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
|
||||||
@ -1460,7 +1497,7 @@ func TestRawValues(t *testing.T) {
|
|||||||
Q := dDbBaser.TableQuote()
|
Q := dDbBaser.TableQuote()
|
||||||
|
|
||||||
var maps []Params
|
var maps []Params
|
||||||
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q)
|
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q)
|
||||||
num, err := dORM.Raw(query, 1).Values(&maps)
|
num, err := dORM.Raw(query, 1).Values(&maps)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
55
router.go
55
router.go
@ -44,6 +44,11 @@ var (
|
|||||||
"GetControllerAndAction"}
|
"GetControllerAndAction"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
|
||||||
|
func ExceptMethodAppend(action string) {
|
||||||
|
exceptMethod = append(exceptMethod, action)
|
||||||
|
}
|
||||||
|
|
||||||
type controllerInfo struct {
|
type controllerInfo struct {
|
||||||
pattern string
|
pattern string
|
||||||
regex *regexp.Regexp
|
regex *regexp.Regexp
|
||||||
@ -621,29 +626,37 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
context.Input.Body()
|
context.Input.Body()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if context.Input.RunController != nil && context.Input.RunMethod != "" {
|
||||||
|
findrouter = true
|
||||||
|
runMethod = context.Input.RunMethod
|
||||||
|
runrouter = context.Input.RunController
|
||||||
|
}
|
||||||
|
|
||||||
//first find path from the fixrouters to Improve Performance
|
//first find path from the fixrouters to Improve Performance
|
||||||
for _, route := range p.fixrouters {
|
if !findrouter {
|
||||||
n := len(requestPath)
|
for _, route := range p.fixrouters {
|
||||||
if requestPath == route.pattern {
|
n := len(requestPath)
|
||||||
runMethod = p.getRunMethod(r.Method, context, route)
|
if requestPath == route.pattern {
|
||||||
if runMethod != "" {
|
runMethod = p.getRunMethod(r.Method, context, route)
|
||||||
runrouter = route.controllerType
|
if runMethod != "" {
|
||||||
findrouter = true
|
runrouter = route.controllerType
|
||||||
break
|
findrouter = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
// pattern /admin url /admin 200 /admin/ 200
|
||||||
// pattern /admin url /admin 200 /admin/ 200
|
// pattern /admin/ url /admin 301 /admin/ 200
|
||||||
// pattern /admin/ url /admin 301 /admin/ 200
|
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
|
||||||
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
|
http.Redirect(w, r, requestPath+"/", 301)
|
||||||
http.Redirect(w, r, requestPath+"/", 301)
|
goto Admin
|
||||||
goto Admin
|
}
|
||||||
}
|
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
|
||||||
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
|
runMethod = p.getRunMethod(r.Method, context, route)
|
||||||
runMethod = p.getRunMethod(r.Method, context, route)
|
if runMethod != "" {
|
||||||
if runMethod != "" {
|
runrouter = route.controllerType
|
||||||
runrouter = route.controllerType
|
findrouter = true
|
||||||
findrouter = true
|
break
|
||||||
break
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,6 +118,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
pder.maxlifetime = maxlifetime
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
227
session/sess_postgresql.go
Normal file
227
session/sess_postgresql.go
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
beego session provider for postgresql
|
||||||
|
-------------------------------------
|
||||||
|
|
||||||
|
depends on github.com/lib/pq:
|
||||||
|
|
||||||
|
go install github.com/lib/pq
|
||||||
|
|
||||||
|
|
||||||
|
needs this table in your database:
|
||||||
|
|
||||||
|
CREATE TABLE session (
|
||||||
|
session_key char(64) NOT NULL,
|
||||||
|
session_data bytea,
|
||||||
|
session_expiry timestamp NOT NULL,
|
||||||
|
CONSTRAINT session_key PRIMARY KEY(session_key)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
will be activated with these settings in app.conf:
|
||||||
|
|
||||||
|
SessionOn = true
|
||||||
|
SessionProvider = postgresql
|
||||||
|
SessionSavePath = "user=a password=b dbname=c sslmode=disable"
|
||||||
|
SessionName = session
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
var postgresqlpder = &PostgresqlProvider{}
|
||||||
|
|
||||||
|
// postgresql session store
|
||||||
|
type PostgresqlSessionStore struct {
|
||||||
|
c *sql.DB
|
||||||
|
sid string
|
||||||
|
lock sync.RWMutex
|
||||||
|
values map[interface{}]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set value in postgresql session.
|
||||||
|
// it is temp value in map.
|
||||||
|
func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
st.values[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get value from postgresql session
|
||||||
|
func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
|
||||||
|
st.lock.RLock()
|
||||||
|
defer st.lock.RUnlock()
|
||||||
|
if v, ok := st.values[key]; ok {
|
||||||
|
return v
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete value in postgresql session
|
||||||
|
func (st *PostgresqlSessionStore) Delete(key interface{}) error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
delete(st.values, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear all values in postgresql session
|
||||||
|
func (st *PostgresqlSessionStore) Flush() error {
|
||||||
|
st.lock.Lock()
|
||||||
|
defer st.lock.Unlock()
|
||||||
|
st.values = make(map[interface{}]interface{})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get session id of this postgresql session store
|
||||||
|
func (st *PostgresqlSessionStore) SessionID() string {
|
||||||
|
return st.sid
|
||||||
|
}
|
||||||
|
|
||||||
|
// save postgresql session values to database.
|
||||||
|
// must call this method to save values to database.
|
||||||
|
func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||||
|
defer st.c.Close()
|
||||||
|
b, err := encodeGob(st.values)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
|
||||||
|
b, time.Now().Format(time.RFC3339), st.sid)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// postgresql session provider
|
||||||
|
type PostgresqlProvider struct {
|
||||||
|
maxlifetime int64
|
||||||
|
savePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to postgresql
|
||||||
|
func (mp *PostgresqlProvider) connectInit() *sql.DB {
|
||||||
|
db, e := sql.Open("postgres", mp.savePath)
|
||||||
|
if e != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// init postgresql session.
|
||||||
|
// savepath is the connection string of postgresql.
|
||||||
|
func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||||
|
mp.maxlifetime = maxlifetime
|
||||||
|
mp.savePath = savePath
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get postgresql session by sid
|
||||||
|
func (mp *PostgresqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||||
|
c := mp.connectInit()
|
||||||
|
row := c.QueryRow("select session_data from session where session_key=$1", sid)
|
||||||
|
var sessiondata []byte
|
||||||
|
err := row.Scan(&sessiondata)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
|
||||||
|
sid, "", time.Now().Format(time.RFC3339))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var kv map[interface{}]interface{}
|
||||||
|
if len(sessiondata) == 0 {
|
||||||
|
kv = make(map[interface{}]interface{})
|
||||||
|
} else {
|
||||||
|
kv, err = decodeGob(sessiondata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
|
||||||
|
return rs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check postgresql session exist
|
||||||
|
func (mp *PostgresqlProvider) SessionExist(sid string) bool {
|
||||||
|
c := mp.connectInit()
|
||||||
|
defer c.Close()
|
||||||
|
row := c.QueryRow("select session_data from session where session_key=$1", sid)
|
||||||
|
var sessiondata []byte
|
||||||
|
err := row.Scan(&sessiondata)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate new sid for postgresql session
|
||||||
|
func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||||
|
c := mp.connectInit()
|
||||||
|
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
|
||||||
|
var sessiondata []byte
|
||||||
|
err := row.Scan(&sessiondata)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
|
||||||
|
oldsid, "", time.Now().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
|
||||||
|
var kv map[interface{}]interface{}
|
||||||
|
if len(sessiondata) == 0 {
|
||||||
|
kv = make(map[interface{}]interface{})
|
||||||
|
} else {
|
||||||
|
kv, err = decodeGob(sessiondata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
|
||||||
|
return rs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete postgresql session by sid
|
||||||
|
func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
|
||||||
|
c := mp.connectInit()
|
||||||
|
c.Exec("DELETE FROM session where session_key=$1", sid)
|
||||||
|
c.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete expired values in postgresql session
|
||||||
|
func (mp *PostgresqlProvider) SessionGC() {
|
||||||
|
c := mp.connectInit()
|
||||||
|
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// count values in postgresql session
|
||||||
|
func (mp *PostgresqlProvider) SessionAll() int {
|
||||||
|
c := mp.connectInit()
|
||||||
|
defer c.Close()
|
||||||
|
var total int
|
||||||
|
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register("postgresql", postgresqlpder)
|
||||||
|
}
|
@ -151,7 +151,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
|
|||||||
fileabspath = filepath.Join(root, file)
|
fileabspath = filepath.Join(root, file)
|
||||||
}
|
}
|
||||||
if e := utils.FileExists(fileabspath); !e {
|
if e := utils.FileExists(fileabspath); !e {
|
||||||
panic("can't find template file" + file)
|
panic("can't find template file:" + file)
|
||||||
}
|
}
|
||||||
data, err := ioutil.ReadFile(fileabspath)
|
data, err := ioutil.ReadFile(fileabspath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -67,7 +67,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r Required) DefaultMessage() string {
|
func (r Required) DefaultMessage() string {
|
||||||
return "Required"
|
return fmt.Sprint(MessageTmpls["Required"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Required) GetKey() string {
|
func (r Required) GetKey() string {
|
||||||
|
Loading…
Reference in New Issue
Block a user