1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-10 18:20:55 +00:00

version 1.1.2 release

This commit is contained in:
astaxie 2014-04-03 15:56:31 +08:00
commit 6497f29ed7
21 changed files with 600 additions and 136 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
.DS_Store .DS_Store
*.swp
*.swo

View File

@ -2,6 +2,7 @@ package beego
import ( import (
"net/http" "net/http"
"os"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -12,7 +13,7 @@ import (
) )
// beego web framework version. // beego web framework version.
const VERSION = "1.1.1" const VERSION = "1.1.2"
type hookfunc func() error //hook function to run type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc var hooks []hookfunc //hook function slice to store the hookfunc
@ -174,6 +175,16 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application. // Run beego application.
// it's alias of App.Run. // it's alias of App.Run.
func Run() { func Run() {
initBeforeHttpRun()
if EnableAdmin {
go BeeAdminApp.Run()
}
BeeApp.Run()
}
func initBeforeHttpRun() {
// if AppConfigPath not In the conf/app.conf reParse config // if AppConfigPath not In the conf/app.conf reParse config
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") { if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
err := ParseConfig() err := ParseConfig()
@ -222,12 +233,13 @@ func Run() {
middleware.VERSION = VERSION middleware.VERSION = VERSION
middleware.AppName = AppName middleware.AppName = AppName
middleware.RegisterErrorHandler() middleware.RegisterErrorHandler()
}
if EnableAdmin { func TestBeegoInit(apppath string) {
go BeeAdminApp.Run() AppPath = apppath
} AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
os.Chdir(AppPath)
BeeApp.Run() initBeforeHttpRun()
} }
func init() { func init() {

View File

@ -11,12 +11,14 @@ import (
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
) )
var ( var (
BeeApp *App // beego application BeeApp *App // beego application
AppName string AppName string
AppPath string AppPath string
workPath string
AppConfigPath string AppConfigPath string
StaticDir map[string]string StaticDir map[string]string
TemplateCache map[string]*template.Template // template caching map TemplateCache map[string]*template.Template // template caching map
@ -58,15 +60,28 @@ var (
EnableAdmin bool // flag of enable admin module to log every request info. EnableAdmin bool // flag of enable admin module to log every request info.
AdminHttpAddr string // http server configurations for admin module. AdminHttpAddr string // http server configurations for admin module.
AdminHttpPort int AdminHttpPort int
FlashName string // name of the flash variable found in response header and cookie
FlashSeperator string // used to seperate flash key:value
) )
func init() { func init() {
// create beego application // create beego application
BeeApp = NewApp() BeeApp = NewApp()
workPath, _ = os.Getwd()
workPath, _ = filepath.Abs(workPath)
// initialize default configurations // initialize default configurations
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
os.Chdir(AppPath)
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if workPath != AppPath {
if utils.FileExists(AppConfigPath) {
os.Chdir(AppPath)
} else {
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
}
}
StaticDir = make(map[string]string) StaticDir = make(map[string]string)
StaticDir["/static"] = "static" StaticDir["/static"] = "static"
@ -105,8 +120,6 @@ func init() {
EnableGzip = false EnableGzip = false
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
HttpServerTimeOut = 0 HttpServerTimeOut = 0
ErrorsShow = true ErrorsShow = true
@ -123,6 +136,9 @@ func init() {
AdminHttpAddr = "127.0.0.1" AdminHttpAddr = "127.0.0.1"
AdminHttpPort = 8088 AdminHttpPort = 8088
FlashName = "BEEGO_FLASH"
FlashSeperator = "BEEGOFLASH"
runtime.GOMAXPROCS(runtime.NumCPU()) runtime.GOMAXPROCS(runtime.NumCPU())
// init BeeLogger // init BeeLogger
@ -271,6 +287,14 @@ func ParseConfig() (err error) {
BeegoServerName = serverName BeegoServerName = serverName
} }
if flashname := AppConfig.String("FlashName"); flashname != "" {
FlashName = flashname
}
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
FlashSeperator = flashseperator
}
if sd := AppConfig.String("StaticDir"); sd != "" { if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range StaticDir { for k := range StaticDir {
delete(StaticDir, k) delete(StaticDir, k)

15
config_test.go Normal file
View File

@ -0,0 +1,15 @@
package beego
import (
"testing"
)
func TestDefaults(t *testing.T) {
if FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
if FlashSeperator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"reflect"
"strconv" "strconv"
"strings" "strings"
@ -13,11 +14,13 @@ import (
// BeegoInput operates the http request header ,data ,cookie and body. // BeegoInput operates the http request header ,data ,cookie and body.
// it also contains router params and current session. // it also contains router params and current session.
type BeegoInput struct { type BeegoInput struct {
CruSession session.SessionStore CruSession session.SessionStore
Params map[string]string Params map[string]string
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request Request *http.Request
RequestBody []byte RequestBody []byte
RunController reflect.Type
RunMethod string
} }
// NewInput return BeegoInput generated by http.Request. // NewInput return BeegoInput generated by http.Request.

View File

@ -62,7 +62,6 @@ type ControllerInterface interface {
// Init generates default values of controller operations. // Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Data = make(map[interface{}]interface{})
c.Layout = "" c.Layout = ""
c.TplNames = "" c.TplNames = ""
c.controllerName = controllerName c.controllerName = controllerName

View File

@ -6,9 +6,6 @@ import (
"strings" "strings"
) )
// the separation string when encoding flash data.
const BEEGO_FLASH_SEP = "#BEEGOFLASH#"
// FlashData is a tools to maintain data when using across request. // FlashData is a tools to maintain data when using across request.
type FlashData struct { type FlashData struct {
Data map[string]string Data map[string]string
@ -54,29 +51,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data c.Data["flash"] = fd.Data
var flashValue string var flashValue string
for key, value := range fd.Data { for key, value := range fd.Data {
flashValue += "\x00" + key + BEEGO_FLASH_SEP + value + "\x00" flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
} }
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/") c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
} }
// ReadFromRequest parsed flash data from encoded values in cookie. // ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData { func ReadFromRequest(c *Controller) *FlashData {
flash := &FlashData{ flash := NewFlash()
Data: make(map[string]string), if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
}
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
v, _ := url.QueryUnescape(cookie.Value) v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00") vals := strings.Split(v, "\x00")
for _, v := range vals { for _, v := range vals {
if len(v) > 0 { if len(v) > 0 {
kv := strings.Split(v, BEEGO_FLASH_SEP) kv := strings.Split(v, FlashSeperator)
if len(kv) == 2 { if len(kv) == 2 {
flash.Data[kv[0]] = kv[1] flash.Data[kv[0]] = kv[1]
} }
} }
} }
//read one time then delete it //read one time then delete it
c.Ctx.SetCookie("BEEGO_FLASH", "", -1, "/") c.Ctx.SetCookie(FlashName, "", -1, "/")
} }
c.Data["flash"] = flash.Data c.Data["flash"] = flash.Data
return flash return flash

40
flash_test.go Normal file
View File

@ -0,0 +1,40 @@
package beego
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type TestFlashController struct {
Controller
}
func (this *TestFlashController) TestWriteFlash() {
flash := NewFlash()
flash.Notice("TestFlashString")
flash.Store(&this.Controller)
// we choose to serve json because we don't want to load a template html file
this.ServeJson(true)
}
func TestFlashHeader(t *testing.T) {
// create fake GET request
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// setup the handler
handler := NewControllerRegistor()
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
handler.ServeHTTP(w, r)
// get the Set-Cookie value
sc := w.Header().Get("Set-Cookie")
// match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion
if res != true {
t.Errorf("TestFlashHeader() unable to validate flash message")
}
}

13
log.go
View File

@ -22,12 +22,21 @@ func SetLevel(l int) {
BeeLogger.SetLevel(l) BeeLogger.SetLevel(l)
} }
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
// logger references the used application logger. // logger references the used application logger.
var BeeLogger *logs.BeeLogger var BeeLogger *logs.BeeLogger
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adaptername string, config string) { func SetLogger(adaptername string, config string) error {
BeeLogger.SetLogger(adaptername, config) err := BeeLogger.SetLogger(adaptername, config)
if err != nil {
return err
}
return nil
} }
// Trace logs a message at trace level. // Trace logs a message at trace level.

View File

@ -6,6 +6,7 @@ import (
func TestConsole(t *testing.T) { func TestConsole(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "") log.SetLogger("console", "")
log.Trace("trace") log.Trace("trace")
log.Info("info") log.Info("info")
@ -23,6 +24,7 @@ func TestConsole(t *testing.T) {
func BenchmarkConsole(b *testing.B) { func BenchmarkConsole(b *testing.B) {
log := NewLogger(10000) log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "") log.SetLogger("console", "")
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
log.Trace("trace") log.Trace("trace")

View File

@ -97,12 +97,12 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
if len(w.Filename) == 0 { if len(w.Filename) == 0 {
return errors.New("jsonconfig must have filename") return errors.New("jsonconfig must have filename")
} }
err = w.StartLogger() err = w.startLogger()
return err return err
} }
// start file logger. create log file and set to locker-inside file writer. // start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error { func (w *FileLogWriter) startLogger() error {
fd, err := w.createLogFile() fd, err := w.createLogFile()
if err != nil { if err != nil {
return err return err
@ -199,7 +199,7 @@ func (w *FileLogWriter) DoRotate() error {
} }
// re-start logger // re-start logger
err = w.StartLogger() err = w.startLogger()
if err != nil { if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err) return fmt.Errorf("Rotate StartLogger: %s\n", err)
} }

View File

@ -2,6 +2,8 @@ package logs
import ( import (
"fmt" "fmt"
"path"
"runtime"
"sync" "sync"
) )
@ -43,10 +45,12 @@ func Register(name string, log loggerType) {
// BeeLogger is default logger in beego application. // BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers. // it can contain several providers and log message into all providers.
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
msg chan *logMsg enableFuncCallDepth bool
outputs map[string]LoggerInterface loggerFuncCallDepth int
msg chan *logMsg
outputs map[string]LoggerInterface
} }
type logMsg struct { type logMsg struct {
@ -59,10 +63,11 @@ type logMsg struct {
// if the buffering chan is full, logger adapters write to file or other way. // if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger { func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.loggerFuncCallDepth = 2
bl.msg = make(chan *logMsg, channellen) bl.msg = make(chan *logMsg, channellen)
bl.outputs = make(map[string]LoggerInterface) bl.outputs = make(map[string]LoggerInterface)
//bl.SetLogger("console", "") // default output to console //bl.SetLogger("console", "") // default output to console
go bl.StartLogger() go bl.startLogger()
return bl return bl
} }
@ -73,7 +78,10 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
defer bl.lock.Unlock() defer bl.lock.Unlock()
if log, ok := adapters[adaptername]; ok { if log, ok := adapters[adaptername]; ok {
lg := log() lg := log()
lg.Init(config) err := lg.Init(config)
if err != nil {
return err
}
bl.outputs[adaptername] = lg bl.outputs[adaptername] = lg
return nil return nil
} else { } else {
@ -100,7 +108,17 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
} }
lm := new(logMsg) lm := new(logMsg)
lm.level = loglevel lm.level = loglevel
lm.msg = msg if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if ok {
_, filename := path.Split(file)
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
} else {
lm.msg = msg
}
} else {
lm.msg = msg
}
bl.msg <- lm bl.msg <- lm
return nil return nil
} }
@ -111,9 +129,19 @@ func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
// set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
// enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
// start logger chan reading. // start logger chan reading.
// when chan is full, write logs. // when chan is full, write logs.
func (bl *BeeLogger) StartLogger() { func (bl *BeeLogger) startLogger() {
for { for {
select { select {
case bm := <-bl.msg: case bm := <-bl.msg:

View File

@ -51,9 +51,16 @@ outFor:
continue continue
} }
switch v := arg.(type) { kind := val.Kind()
case []byte: if kind == reflect.Ptr {
case string: val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time var t time.Time
@ -78,61 +85,66 @@ outFor:
} }
} }
arg = v arg = v
case time.Time: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if fi != nil && fi.fieldType == TypeDateField { arg = val.Int()
arg = v.In(tz).Format(format_Date) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
} else { arg = val.Uint()
arg = v.In(tz).Format(format_DateTime) case reflect.Float32:
} arg, _ = StrTo(ToStr(arg)).Float64()
default: case reflect.Float64:
kind := val.Kind() arg = val.Float()
switch kind { case reflect.Bool:
case reflect.Slice, reflect.Array: arg = val.Bool()
case reflect.Slice, reflect.Array:
var args []interface{} if _, ok := arg.([]byte); ok {
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor continue outFor
}
case reflect.Ptr, reflect.Struct: var args []interface{}
ind := reflect.Indirect(val) for i := 0; i < val.Len(); i++ {
v := val.Index(i)
if ind.Kind() == reflect.Struct { var vu interface{}
typ := ind.Type() if v.CanInterface() {
name := getFullName(typ) vu = v.Interface()
var value interface{} }
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil { if vu == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) continue
} }
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
} else { } else {
arg = ind.Interface() arg = v.In(tz).Format(format_DateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
} }
} }
} }
params = append(params, arg) params = append(params, arg)
} }
return return

View File

@ -144,6 +144,45 @@ type DataNull struct {
NullInt64 sql.NullInt64 `orm:"null"` NullInt64 sql.NullInt64 `orm:"null"`
} }
type String string
type Boolean bool
type Byte byte
type Rune rune
type Int int
type Int8 int8
type Int16 int16
type Int32 int32
type Int64 int64
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
Byte Byte
Rune Rune
Int Int
Int8 Int8
Int16 Int16
Int32 Int32
Int64 Int64
Uint Uint
Uint8 Uint8
Uint16 Uint16
Uint32 Uint32
Uint64 Uint64
Float32 Float32
Float64 Float64
Decimal Float64 `orm:"digits(8);decimals(4)"`
}
// only for mysql // only for mysql
type UserBig struct { type UserBig struct {
Id uint64 Id uint64
@ -155,7 +194,7 @@ type User struct {
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"` Email string `orm:"size(100)"`
Password string `orm:"size(100)"` Password string `orm:"size(100)"`
Status int16 Status int16 `orm:"column(Status)"`
IsStaff bool IsStaff bool
IsActive bool `orm:"default(1)"` IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"` Created time.Time `orm:"auto_now_add;type(date)"`

View File

@ -80,7 +80,6 @@ func getTableUnique(val reflect.Value) [][]string {
// get snaked column name // get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col)
column := col column := col
if col == "" { if col == "" {
column = snakeString(sf.Name) column = snakeString(sf.Name)
@ -99,34 +98,41 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value // return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Interface().(type) { switch elm.Kind() {
case int8: case reflect.Int8:
ft = TypeBitField ft = TypeBitField
case int16: case reflect.Int16:
ft = TypeSmallIntegerField ft = TypeSmallIntegerField
case int32, int: case reflect.Int32, reflect.Int:
ft = TypeIntegerField ft = TypeIntegerField
case int64, sql.NullInt64: case reflect.Int64:
ft = TypeBigIntegerField ft = TypeBigIntegerField
case uint8: case reflect.Uint8:
ft = TypePositiveBitField ft = TypePositiveBitField
case uint16: case reflect.Uint16:
ft = TypePositiveSmallIntegerField ft = TypePositiveSmallIntegerField
case uint32, uint: case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField ft = TypePositiveIntegerField
case uint64: case reflect.Uint64:
ft = TypePositiveBigIntegerField ft = TypePositiveBigIntegerField
case float32, float64, sql.NullFloat64: case reflect.Float32, reflect.Float64:
ft = TypeFloatField ft = TypeFloatField
case bool, sql.NullBool: case reflect.Bool:
ft = TypeBooleanField ft = TypeBooleanField
case string, sql.NullString: case reflect.String:
ft = TypeCharField ft = TypeCharField
default: default:
if elm.CanInterface() { switch elm.Interface().(type) {
if _, ok := elm.Interface().(time.Time); ok { case sql.NullInt64:
ft = TypeDateTimeField ft = TypeBigIntegerField
} case sql.NullFloat64:
ft = TypeFloatField
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeCharField
case time.Time:
ft = TypeDateTimeField
} }
} }
if ft&IsFieldType == 0 { if ft&IsFieldType == 0 {

View File

@ -149,7 +149,7 @@ func TestGetDB(t *testing.T) {
} }
func TestSyncDb(t *testing.T) { func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -165,7 +165,7 @@ func TestSyncDb(t *testing.T) {
} }
func TestRegisterModels(t *testing.T) { func TestRegisterModels(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -309,6 +309,39 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
} }
func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
e.Set(reflect.ValueOf(value).Convert(e.Type()))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
vu := e.Interface()
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
throwFail(t, AssertIs(vu == value, true), value, vu)
}
}
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
profile := NewProfile() profile := NewProfile()
profile.Age = 30 profile.Age = 30
@ -562,6 +595,10 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", String("slene")).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", "slene").Count() num, err = qs.Filter("user_name__exact", "slene").Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -602,11 +639,11 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
num, err = qs.Filter("status__lt", 3).Count() num, err = qs.Filter("status__lt", Uint(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("status__lte", 3).Count() num, err = qs.Filter("status__lte", Int(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
@ -1380,7 +1417,7 @@ func TestRawQueryRow(t *testing.T) {
) )
cols = []string{ cols = []string{
"id", "status", "profile_id", "id", "Status", "profile_id",
} }
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
@ -1460,7 +1497,7 @@ func TestRawValues(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
var maps []Params var maps []Params
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q) query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q)
num, err := dORM.Raw(query, 1).Values(&maps) num, err := dORM.Raw(query, 1).Values(&maps)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))

View File

@ -44,6 +44,11 @@ var (
"GetControllerAndAction"} "GetControllerAndAction"}
) )
// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
type controllerInfo struct { type controllerInfo struct {
pattern string pattern string
regex *regexp.Regexp regex *regexp.Regexp
@ -621,29 +626,37 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Input.Body() context.Input.Body()
} }
if context.Input.RunController != nil && context.Input.RunMethod != "" {
findrouter = true
runMethod = context.Input.RunMethod
runrouter = context.Input.RunController
}
//first find path from the fixrouters to Improve Performance //first find path from the fixrouters to Improve Performance
for _, route := range p.fixrouters { if !findrouter {
n := len(requestPath) for _, route := range p.fixrouters {
if requestPath == route.pattern { n := len(requestPath)
runMethod = p.getRunMethod(r.Method, context, route) if requestPath == route.pattern {
if runMethod != "" { runMethod = p.getRunMethod(r.Method, context, route)
runrouter = route.controllerType if runMethod != "" {
findrouter = true runrouter = route.controllerType
break findrouter = true
break
}
} }
} // pattern /admin url /admin 200 /admin/ 200
// pattern /admin url /admin 200 /admin/ 200 // pattern /admin/ url /admin 301 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200 if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern { http.Redirect(w, r, requestPath+"/", 301)
http.Redirect(w, r, requestPath+"/", 301) goto Admin
goto Admin }
} if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath { runMethod = p.getRunMethod(r.Method, context, route)
runMethod = p.getRunMethod(r.Method, context, route) if runMethod != "" {
if runMethod != "" { runrouter = route.controllerType
runrouter = route.controllerType findrouter = true
findrouter = true break
break }
} }
} }
} }

View File

@ -118,6 +118,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
if err != nil { if err != nil {
return err return err
} }
pder.maxlifetime = maxlifetime
return nil return nil
} }

227
session/sess_postgresql.go Normal file
View File

@ -0,0 +1,227 @@
package session
/*
beego session provider for postgresql
-------------------------------------
depends on github.com/lib/pq:
go install github.com/lib/pq
needs this table in your database:
CREATE TABLE session (
session_key char(64) NOT NULL,
session_data bytea,
session_expiry timestamp NOT NULL,
CONSTRAINT session_key PRIMARY KEY(session_key)
);
will be activated with these settings in app.conf:
SessionOn = true
SessionProvider = postgresql
SessionSavePath = "user=a password=b dbname=c sslmode=disable"
SessionName = session
*/
import (
"database/sql"
"net/http"
"sync"
"time"
_ "github.com/lib/pq"
)
var postgresqlpder = &PostgresqlProvider{}
// postgresql session store
type PostgresqlSessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// set value in postgresql session.
// it is temp value in map.
func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// get value from postgresql session
func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
// delete value in postgresql session
func (st *PostgresqlSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// clear all values in postgresql session
func (st *PostgresqlSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// get session id of this postgresql session store
func (st *PostgresqlSessionStore) SessionID() string {
return st.sid
}
// save postgresql session values to database.
// must call this method to save values to database.
func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
b, time.Now().Format(time.RFC3339), st.sid)
}
// postgresql session provider
type PostgresqlProvider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
func (mp *PostgresqlProvider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
}
return db
}
// init postgresql session.
// savepath is the connection string of postgresql.
func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// get postgresql session by sid
func (mp *PostgresqlProvider) SessionRead(sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// check postgresql session exist
func (mp *PostgresqlProvider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
} else {
return true
}
}
// generate new sid for postgresql session
func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// delete postgresql session by sid
func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
// delete expired values in postgresql session
func (mp *PostgresqlProvider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
return
}
// count values in postgresql session
func (mp *PostgresqlProvider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
Register("postgresql", postgresqlpder)
}

View File

@ -151,7 +151,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
fileabspath = filepath.Join(root, file) fileabspath = filepath.Join(root, file)
} }
if e := utils.FileExists(fileabspath); !e { if e := utils.FileExists(fileabspath); !e {
panic("can't find template file" + file) panic("can't find template file:" + file)
} }
data, err := ioutil.ReadFile(fileabspath) data, err := ioutil.ReadFile(fileabspath)
if err != nil { if err != nil {

View File

@ -67,7 +67,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
} }
func (r Required) DefaultMessage() string { func (r Required) DefaultMessage() string {
return "Required" return fmt.Sprint(MessageTmpls["Required"])
} }
func (r Required) GetKey() string { func (r Required) GetKey() string {