diff --git a/.go_style b/.go_style deleted file mode 100644 index 26682eed..00000000 --- a/.go_style +++ /dev/null @@ -1,16 +0,0 @@ -{ - "file_line": 500, - "func_line": 80, - "params_num":4, - "results_num":3, - "formated": true, - "pkg_name": true, - "camel_name":true, - "ignore":[ - "a/*", - "b/*/c/*.go" - ], - "fatal":[ - "formated" - ] -} diff --git a/admin.go b/admin.go index b860d268..d918b595 100644 --- a/admin.go +++ b/admin.go @@ -113,8 +113,6 @@ func listConf(rw http.ResponseWriter, r *http.Request) { m["SessionName"] = SessionName m["SessionGCMaxLifetime"] = SessionGCMaxLifetime m["SessionSavePath"] = SessionSavePath - m["SessionHashFunc"] = SessionHashFunc - m["SessionHashKey"] = SessionHashKey m["SessionCookieLifeTime"] = SessionCookieLifeTime m["UseFcgi"] = UseFcgi m["MaxMemory"] = MaxMemory @@ -458,6 +456,7 @@ func (admin *adminApp) Run() { for p, f := range admin.routers { http.Handle(p, f) } + BeeLogger.Info("Admin server Running on %s", addr) err := http.ListenAndServe(addr, nil) if err != nil { BeeLogger.Critical("Admin ListenAndServe: ", err) diff --git a/app.go b/app.go index f1706616..d155c531 100644 --- a/app.go +++ b/app.go @@ -20,13 +20,8 @@ import ( "net/http" "net/http/fcgi" "time" - - "github.com/astaxie/beego/context" ) -// FilterFunc defines filter function type. -type FilterFunc func(*context.Context) - // App defines beego application with a new PatternServeMux. type App struct { Handlers *ControllerRegistor @@ -48,8 +43,6 @@ func (app *App) Run() { addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort) } - BeeLogger.Info("Running on %s", addr) - var ( err error l net.Listener @@ -57,15 +50,24 @@ func (app *App) Run() { endRunning := make(chan bool, 1) if UseFcgi { - if HttpPort == 0 { - l, err = net.Listen("unix", addr) + if UseStdIo { + err = fcgi.Serve(nil, app.Handlers) // standard I/O + if err == nil { + BeeLogger.Info("Use FCGI via standard I/O") + } else { + BeeLogger.Info("Cannot use FCGI via standard I/O", err) + } } else { - l, err = net.Listen("tcp", addr) + if HttpPort == 0 { + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + BeeLogger.Critical("Listen: ", err) + } + err = fcgi.Serve(l, app.Handlers) } - if err != nil { - BeeLogger.Critical("Listen: ", err) - } - err = fcgi.Serve(l, app.Handlers) } else { app.Server.Addr = addr app.Server.Handler = app.Handlers @@ -78,6 +80,7 @@ func (app *App) Run() { if HttpsPort != 0 { app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) } + BeeLogger.Info("https server Running on %s", app.Server.Addr) err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) if err != nil { BeeLogger.Critical("ListenAndServeTLS: ", err) @@ -90,11 +93,29 @@ func (app *App) Run() { if EnableHttpListen { go func() { app.Server.Addr = addr - err := app.Server.ListenAndServe() - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true + BeeLogger.Info("http server Running on %s", app.Server.Addr) + if ListenTCP4 && HttpAddr == "" { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + err = app.Server.Serve(ln) + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + err := app.Server.ListenAndServe() + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } } }() } diff --git a/beego.go b/beego.go index 1277a2a8..ab70a182 100644 --- a/beego.go +++ b/beego.go @@ -38,7 +38,7 @@ import ( ) // beego web framework version. -const VERSION = "1.4.1" +const VERSION = "1.4.2" type hookfunc func() error //hook function to run var hooks []hookfunc //hook function slice to store the hookfunc @@ -308,15 +308,20 @@ func SetStaticPath(url string, path string) *App { // DelStaticPath removes the static folder setting in this url pattern in beego application. func DelStaticPath(url string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + url = strings.TrimRight(url, "/") delete(StaticDir, url) return BeeApp } // InsertFilter adds a FilterFunc with pattern condition and action constant. // The pos means action constant including -// beego.BeforeRouter, beego.AfterStatic, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. -func InsertFilter(pattern string, pos int, filter FilterFunc) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter) +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) return BeeApp } @@ -359,6 +364,9 @@ func initBeforeHttpRun() { } } + //init mime + AddAPPStartHook(initMime) + // do hooks function for _, hk := range hooks { err := hk() @@ -373,10 +381,8 @@ func initBeforeHttpRun() { if sessionConfig == "" { sessionConfig = `{"cookieName":"` + SessionName + `",` + `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` + - `"providerConfig":"` + SessionSavePath + `",` + + `"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` + `"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` + - `"sessionIDHashFunc":"` + SessionHashFunc + `",` + - `"sessionIDHashKey":"` + SessionHashKey + `",` + `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` + `"domain":"` + SessionDomain + `",` + `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}` @@ -404,9 +410,6 @@ func initBeforeHttpRun() { Get("/docs", serverDocs) Get("/docs/*", serverDocs) } - - //init mime - AddAPPStartHook(initMime) } // this function is for test package init diff --git a/cache/cache.go b/cache/cache.go index d4ad5d59..ddb2f857 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -81,13 +81,13 @@ func Register(name string, adapter Cache) { // Create a new cache driver by adapter name and config string. // config need to be correct JSON as string: {"interval":360}. // it will start gc automatically. -func NewCache(adapterName, config string) (adapter Cache, e error) { +func NewCache(adapterName, config string) (adapter Cache, err error) { adapter, ok := adapters[adapterName] if !ok { - e = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) return } - err := adapter.StartAndGC(config) + err = adapter.StartAndGC(config) if err != nil { adapter = nil } diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 35cf88cd..b205545d 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -75,14 +75,13 @@ func (rc *RedisCache) Get(key string) interface{} { // put cache to redis. func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { var err error - if _, err = rc.do("SET", key, val); err != nil { + if _, err = rc.do("SETEX", key, timeout, val); err != nil { return err } if _, err = rc.do("HSET", rc.key, key, true); err != nil { return err } - _, err = rc.do("EXPIRE", key, timeout) return err } diff --git a/config.go b/config.go index db234e16..f4d0aba2 100644 --- a/config.go +++ b/config.go @@ -15,7 +15,6 @@ package beego import ( - "errors" "fmt" "html/template" "os" @@ -41,6 +40,7 @@ var ( EnableHttpListen bool HttpAddr string HttpPort int + ListenTCP4 bool EnableHttpTLS bool HttpsPort int HttpCertFile string @@ -48,20 +48,19 @@ var ( RecoverPanic bool // flag of auto recover panic AutoRender bool // flag of render template automatically ViewsPath string - RunMode string // run mode, "dev" or "prod" - AppConfig config.ConfigContainer + AppConfig *beegoAppConfig + RunMode string // run mode, "dev" or "prod" GlobalSessions *session.Manager // global session mananger SessionOn bool // flag of starting session auto. default is false. SessionProvider string // default session provider, memory, mysql , redis ,etc. SessionName string // the cookie name when saving session id into cookie. SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session. SessionSavePath string // if use mysql/redis/file provider, define save path to connection info. - SessionHashFunc string // session hash generation func. - SessionHashKey string // session hash salt string. SessionCookieLifeTime int // the life time of session id in cookie. SessionAutoSetCookie bool // auto setcookie SessionDomain string // the cookie domain default is empty UseFcgi bool + UseStdIo bool MaxMemory int64 EnableGzip bool // flag of enable gzip DirectoryIndex bool // flag of display directory index. default is false. @@ -81,8 +80,110 @@ var ( FlashSeperator string // used to seperate flash key:value AppConfigProvider string // config provider EnableDocs bool // enable generate docs & server docs API Swagger + RouterCaseSensitive bool // router case sensitive default is true ) +type beegoAppConfig struct { + innerConfig config.ConfigContainer +} + +func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(AppConfigProvider, AppConfigPath) + if err != nil { + return nil, err + } + rac := &beegoAppConfig{ac} + return rac, nil +} + +func (b *beegoAppConfig) Set(key, val string) error { + return b.innerConfig.Set(key, val) +} + +func (b *beegoAppConfig) String(key string) string { + v := b.innerConfig.String(RunMode + "::" + key) + if v == "" { + return b.innerConfig.String(key) + } + return v +} + +func (b *beegoAppConfig) Strings(key string) []string { + v := b.innerConfig.Strings(RunMode + "::" + key) + if len(v) == 0 { + return b.innerConfig.Strings(key) + } + return v +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + v, err := b.innerConfig.Int(RunMode + "::" + key) + if err != nil { + return b.innerConfig.Int(key) + } + return v, nil +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + v, err := b.innerConfig.Int64(RunMode + "::" + key) + if err != nil { + return b.innerConfig.Int64(key) + } + return v, nil +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + v, err := b.innerConfig.Bool(RunMode + "::" + key) + if err != nil { + return b.innerConfig.Bool(key) + } + return v, nil +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + v, err := b.innerConfig.Float(RunMode + "::" + key) + if err != nil { + return b.innerConfig.Float(key) + } + return v, nil +} + +func (b *beegoAppConfig) DefaultString(key string, defaultval string) string { + return b.innerConfig.DefaultString(key, defaultval) +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string { + return b.innerConfig.DefaultStrings(key, defaultval) +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int { + return b.innerConfig.DefaultInt(key, defaultval) +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 { + return b.innerConfig.DefaultInt64(key, defaultval) +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool { + return b.innerConfig.DefaultBool(key, defaultval) +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 { + return b.innerConfig.DefaultFloat(key, defaultval) +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} + func init() { // create beego application BeeApp = NewApp() @@ -134,12 +235,11 @@ func init() { SessionName = "beegosessionID" SessionGCMaxLifetime = 3600 SessionSavePath = "" - SessionHashFunc = "sha1" - SessionHashKey = "beegoserversessionkey" SessionCookieLifeTime = 0 //set cookie default is the brower life SessionAutoSetCookie = true UseFcgi = false + UseStdIo = false MaxMemory = 1 << 26 //64MB @@ -164,6 +264,8 @@ func init() { FlashName = "BEEGO_FLASH" FlashSeperator = "BEEGOFLASH" + RouterCaseSensitive = true + runtime.GOMAXPROCS(runtime.NumCPU()) // init BeeLogger @@ -172,262 +274,211 @@ func init() { if err != nil { fmt.Println("init console log error:", err) } + SetLogFuncCall(true) err = ParseConfig() if err != nil && !os.IsNotExist(err) { // for init if doesn't have app.conf will not panic - Info(err) + ac := config.NewFakeConfig() + AppConfig = &beegoAppConfig{ac} + Warning(err) } } // ParseConfig parsed default config file. // now only support ini, next will support json. func ParseConfig() (err error) { - AppConfig, err = config.NewConfig(AppConfigProvider, AppConfigPath) + AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath) if err != nil { - AppConfig = config.NewFakeConfig() return err - } else { + } + envRunMode := os.Getenv("BEEGO_RUNMODE") + // set the runmode first + if envRunMode != "" { + RunMode = envRunMode + } else if runmode := AppConfig.String("RunMode"); runmode != "" { + RunMode = runmode + } - if v, err := GetConfig("string", "HttpAddr"); err == nil { - HttpAddr = v.(string) + HttpAddr = AppConfig.String("HttpAddr") + + if v, err := AppConfig.Int("HttpPort"); err == nil { + HttpPort = v + } + + if v, err := AppConfig.Bool("ListenTCP4"); err == nil { + ListenTCP4 = v + } + + if v, err := AppConfig.Bool("EnableHttpListen"); err == nil { + EnableHttpListen = v + } + + if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil { + MaxMemory = maxmemory + } + + if appname := AppConfig.String("AppName"); appname != "" { + AppName = appname + } + + if autorender, err := AppConfig.Bool("AutoRender"); err == nil { + AutoRender = autorender + } + + if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil { + RecoverPanic = autorecover + } + + if views := AppConfig.String("ViewsPath"); views != "" { + ViewsPath = views + } + + if sessionon, err := AppConfig.Bool("SessionOn"); err == nil { + SessionOn = sessionon + } + + if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" { + SessionProvider = sessProvider + } + + if sessName := AppConfig.String("SessionName"); sessName != "" { + SessionName = sessName + } + + if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" { + SessionSavePath = sesssavepath + } + + if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 { + SessionGCMaxLifetime = sessMaxLifeTime + } + + if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 { + SessionCookieLifeTime = sesscookielifetime + } + + if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil { + UseFcgi = usefcgi + } + + if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil { + EnableGzip = enablegzip + } + + if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil { + DirectoryIndex = directoryindex + } + + if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil { + HttpServerTimeOut = timeout + } + + if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil { + ErrorsShow = errorsshow + } + + if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil { + CopyRequestBody = copyrequestbody + } + + if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" { + XSRFKEY = xsrfkey + } + + if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil { + EnableXSRF = enablexsrf + } + + if expire, err := AppConfig.Int("XSRFExpire"); err == nil { + XSRFExpire = expire + } + + if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" { + TemplateLeft = tplleft + } + + if tplright := AppConfig.String("TemplateRight"); tplright != "" { + TemplateRight = tplright + } + + if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil { + EnableHttpTLS = httptls + } + + if httpsport, err := AppConfig.Int("HttpsPort"); err == nil { + HttpsPort = httpsport + } + + if certfile := AppConfig.String("HttpCertFile"); certfile != "" { + HttpCertFile = certfile + } + + if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" { + HttpKeyFile = keyfile + } + + if serverName := AppConfig.String("BeegoServerName"); serverName != "" { + BeegoServerName = serverName + } + + if flashname := AppConfig.String("FlashName"); flashname != "" { + FlashName = flashname + } + + if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" { + FlashSeperator = flashseperator + } + + if sd := AppConfig.String("StaticDir"); sd != "" { + for k := range StaticDir { + delete(StaticDir, k) } - - if v, err := GetConfig("int", "HttpPort"); err == nil { - HttpPort = v.(int) - } - - if v, err := GetConfig("bool", "EnableHttpListen"); err == nil { - EnableHttpListen = v.(bool) - } - - if maxmemory, err := GetConfig("int64", "MaxMemory"); err == nil { - MaxMemory = maxmemory.(int64) - } - - if appname, _ := GetConfig("string", "AppName"); appname != "" { - AppName = appname.(string) - } - - if runmode, _ := GetConfig("string", "RunMode"); runmode != "" { - RunMode = runmode.(string) - } - - if autorender, err := GetConfig("bool", "AutoRender"); err == nil { - AutoRender = autorender.(bool) - } - - if autorecover, err := GetConfig("bool", "RecoverPanic"); err == nil { - RecoverPanic = autorecover.(bool) - } - - if views, _ := GetConfig("string", "ViewsPath"); views != "" { - ViewsPath = views.(string) - } - - if sessionon, err := GetConfig("bool", "SessionOn"); err == nil { - SessionOn = sessionon.(bool) - } - - if sessProvider, _ := GetConfig("string", "SessionProvider"); sessProvider != "" { - SessionProvider = sessProvider.(string) - } - - if sessName, _ := GetConfig("string", "SessionName"); sessName != "" { - SessionName = sessName.(string) - } - - if sesssavepath, _ := GetConfig("string", "SessionSavePath"); sesssavepath != "" { - SessionSavePath = sesssavepath.(string) - } - - if sesshashfunc, _ := GetConfig("string", "SessionHashFunc"); sesshashfunc != "" { - SessionHashFunc = sesshashfunc.(string) - } - - if sesshashkey, _ := GetConfig("string", "SessionHashKey"); sesshashkey != "" { - SessionHashKey = sesshashkey.(string) - } - - if sessMaxLifeTime, err := GetConfig("int64", "SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 { - SessionGCMaxLifetime = sessMaxLifeTime.(int64) - } - - if sesscookielifetime, err := GetConfig("int", "SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 { - SessionCookieLifeTime = sesscookielifetime.(int) - } - - if usefcgi, err := GetConfig("bool", "UseFcgi"); err == nil { - UseFcgi = usefcgi.(bool) - } - - if enablegzip, err := GetConfig("bool", "EnableGzip"); err == nil { - EnableGzip = enablegzip.(bool) - } - - if directoryindex, err := GetConfig("bool", "DirectoryIndex"); err == nil { - DirectoryIndex = directoryindex.(bool) - } - - if timeout, err := GetConfig("int64", "HttpServerTimeOut"); err == nil { - HttpServerTimeOut = timeout.(int64) - } - - if errorsshow, err := GetConfig("bool", "ErrorsShow"); err == nil { - ErrorsShow = errorsshow.(bool) - } - - if copyrequestbody, err := GetConfig("bool", "CopyRequestBody"); err == nil { - CopyRequestBody = copyrequestbody.(bool) - } - - if xsrfkey, _ := GetConfig("string", "XSRFKEY"); xsrfkey != "" { - XSRFKEY = xsrfkey.(string) - } - - if enablexsrf, err := GetConfig("bool", "EnableXSRF"); err == nil { - EnableXSRF = enablexsrf.(bool) - } - - if expire, err := GetConfig("int", "XSRFExpire"); err == nil { - XSRFExpire = expire.(int) - } - - if tplleft, _ := GetConfig("string", "TemplateLeft"); tplleft != "" { - TemplateLeft = tplleft.(string) - } - - if tplright, _ := GetConfig("string", "TemplateRight"); tplright != "" { - TemplateRight = tplright.(string) - } - - if httptls, err := GetConfig("bool", "EnableHttpTLS"); err == nil { - EnableHttpTLS = httptls.(bool) - } - - if httpsport, err := GetConfig("int", "HttpsPort"); err == nil { - HttpsPort = httpsport.(int) - } - - if certfile, _ := GetConfig("string", "HttpCertFile"); certfile != "" { - HttpCertFile = certfile.(string) - } - - if keyfile, _ := GetConfig("string", "HttpKeyFile"); keyfile != "" { - HttpKeyFile = keyfile.(string) - } - - if serverName, _ := GetConfig("string", "BeegoServerName"); serverName != "" { - BeegoServerName = serverName.(string) - } - - if flashname, _ := GetConfig("string", "FlashName"); flashname != "" { - FlashName = flashname.(string) - } - - if flashseperator, _ := GetConfig("string", "FlashSeperator"); flashseperator != "" { - FlashSeperator = flashseperator.(string) - } - - if sd, _ := GetConfig("string", "StaticDir"); sd != "" { - for k := range StaticDir { - delete(StaticDir, k) + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1] + } else { + StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0] } - sds := strings.Fields(sd.(string)) - for _, v := range sds { - if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { - StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1] - } else { - StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0] + } + } + + if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + if len(extensions) > 0 { + StaticExtensionsToGzip = []string{} + for _, ext := range extensions { + if len(ext) == 0 { + continue } - } - } - - if sgz, _ := GetConfig("string", "StaticExtensionsToGzip"); sgz != "" { - extensions := strings.Split(sgz.(string), ",") - if len(extensions) > 0 { - StaticExtensionsToGzip = []string{} - for _, ext := range extensions { - if len(ext) == 0 { - continue - } - extWithDot := ext - if extWithDot[:1] != "." { - extWithDot = "." + extWithDot - } - StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot) + extWithDot := ext + if extWithDot[:1] != "." { + extWithDot = "." + extWithDot } + StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot) } } + } - if enableadmin, err := GetConfig("bool", "EnableAdmin"); err == nil { - EnableAdmin = enableadmin.(bool) - } + if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil { + EnableAdmin = enableadmin + } - if adminhttpaddr, _ := GetConfig("string", "AdminHttpAddr"); adminhttpaddr != "" { - AdminHttpAddr = adminhttpaddr.(string) - } + if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" { + AdminHttpAddr = adminhttpaddr + } - if adminhttpport, err := GetConfig("int", "AdminHttpPort"); err == nil { - AdminHttpPort = adminhttpport.(int) - } + if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil { + AdminHttpPort = adminhttpport + } - if enabledocs, err := GetConfig("bool", "EnableDocs"); err == nil { - EnableDocs = enabledocs.(bool) - } + if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil { + EnableDocs = enabledocs + } + + if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil { + RouterCaseSensitive = casesensitive } return nil } - -// Getconfig throw the Runmode -// [dev] -// name = astaixe -// IsEnable = false -// [prod] -// name = slene -// IsEnable = true -// -// usage: -// GetConfig("string", "name") -// GetConfig("bool", "IsEnable") -func GetConfig(typ, key string) (interface{}, error) { - switch typ { - case "string": - v := AppConfig.String(RunMode + "::" + key) - if v == "" { - v = AppConfig.String(key) - } - return v, nil - case "strings": - v := AppConfig.Strings(RunMode + "::" + key) - if len(v) == 0 { - v = AppConfig.Strings(key) - } - return v, nil - case "int": - v, err := AppConfig.Int(RunMode + "::" + key) - if err != nil || v == 0 { - return AppConfig.Int(key) - } - return v, nil - case "bool": - v, err := AppConfig.Bool(RunMode + "::" + key) - if err != nil { - return AppConfig.Bool(key) - } - return v, nil - case "int64": - v, err := AppConfig.Int64(RunMode + "::" + key) - if err != nil || v == 0 { - return AppConfig.Int64(key) - } - return v, nil - case "float": - v, err := AppConfig.Float(RunMode + "::" + key) - if err != nil || v == 0 { - return AppConfig.Float(key) - } - return v, nil - } - return "", errors.New("not support type") -} diff --git a/config/ini.go b/config/ini.go index 1bf2e808..837c9ffe 100644 --- a/config/ini.go +++ b/config/ini.go @@ -48,6 +48,10 @@ type IniConfig struct { // ParseFile creates a new Config and parses the file configuration from the named file. func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { + return ini.parseFile(name) +} + +func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { file, err := os.Open(name) if err != nil { return nil, err @@ -66,6 +70,13 @@ func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { var comment bytes.Buffer buf := bufio.NewReader(file) + // check the BOM + head, err := buf.Peek(3) + if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { + for i := 1; i <= 3; i++ { + buf.ReadByte() + } + } section := DEFAULT_SECTION for { line, _, err := buf.ReadLine() @@ -108,13 +119,48 @@ func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { cfg.data[section] = make(map[string]string) } keyValue := bytes.SplitN(line, bEqual, 2) + + key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive + key = strings.ToLower(key) + + // handle include "other.conf" + if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + includefiles := strings.Fields(key) + if includefiles[0] == "include" && len(includefiles) == 2 { + otherfile := strings.Trim(includefiles[1], "\"") + if !path.IsAbs(otherfile) { + otherfile = path.Join(path.Dir(name), otherfile) + } + i, err := ini.parseFile(otherfile) + if err != nil { + return nil, err + } + for sec, dt := range i.data { + if _, ok := cfg.data[sec]; !ok { + cfg.data[sec] = make(map[string]string) + } + for k, v := range dt { + cfg.data[sec][k] = v + } + } + for sec, comm := range i.sectionComment { + cfg.sectionComment[sec] = comm + } + for k, comm := range i.keyComment { + cfg.keyComment[k] = comm + } + continue + } + } + + if len(keyValue) != 2 { + return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") + } val := bytes.TrimSpace(keyValue[1]) if bytes.HasPrefix(val, bDQuote) { val = bytes.Trim(val, `"`) } - key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive - key = strings.ToLower(key) cfg.data[section][key] = string(val) if comment.Len() > 0 { cfg.keyComment[section+"."+key] = comment.String() diff --git a/context/context.go b/context/context.go index d31076d4..89b5ffe4 100644 --- a/context/context.go +++ b/context/context.go @@ -69,7 +69,8 @@ func (ctx *Context) Abort(status int, body string) { panic(e) } // last panic user string - panic(body) + ctx.ResponseWriter.Write([]byte(body)) + panic("User stop run") } // Write string to response body. diff --git a/controller.go b/controller.go index eee79513..72ba323b 100644 --- a/controller.go +++ b/controller.go @@ -382,8 +382,37 @@ func (c *Controller) GetStrings(key string) []string { return []string{} } -// GetInt returns input value as int64. -func (c *Controller) GetInt(key string) (int64, error) { +// GetInt returns input as an int +func (c *Controller) GetInt(key string) (int, error) { + return strconv.Atoi(c.Ctx.Input.Query(key)) +} + +// GetInt8 return input as an int8 +func (c *Controller) GetInt8(key string) (int8, error) { + i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 8) + i8 := int8(i64) + + return i8, err +} + +// GetInt16 returns input as an int16 +func (c *Controller) GetInt16(key string) (int16, error) { + i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 16) + i16 := int16(i64) + + return i16, err +} + +// GetInt32 returns input as an int32 +func (c *Controller) GetInt32(key string) (int32, error) { + i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32) + i32 := int32(i64) + + return i32, err +} + +// GetInt64 returns input value as int64. +func (c *Controller) GetInt64(key string) (int64, error) { return strconv.ParseInt(c.Ctx.Input.Query(key), 10, 64) } diff --git a/controller_test.go b/controller_test.go new file mode 100644 index 00000000..15938cdc --- /dev/null +++ b/controller_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "github.com/astaxie/beego/context" +) + +func ExampleGetInt() { + + i := &context.BeegoInput{Params: map[string]string{"age": "40"}} + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + + val, _ := ctrlr.GetInt("age") + fmt.Printf("%T", val) + //Output: int +} + +func ExampleGetInt8() { + + i := &context.BeegoInput{Params: map[string]string{"age": "40"}} + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + + val, _ := ctrlr.GetInt8("age") + fmt.Printf("%T", val) + //Output: int8 +} + +func ExampleGetInt16() { + + i := &context.BeegoInput{Params: map[string]string{"age": "40"}} + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + + val, _ := ctrlr.GetInt16("age") + fmt.Printf("%T", val) + //Output: int16 +} + +func ExampleGetInt32() { + + i := &context.BeegoInput{Params: map[string]string{"age": "40"}} + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + + val, _ := ctrlr.GetInt32("age") + fmt.Printf("%T", val) + //Output: int32 +} + +func ExampleGetInt64() { + + i := &context.BeegoInput{Params: map[string]string{"age": "40"}} + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + + val, _ := ctrlr.GetInt64("age") + fmt.Printf("%T", val) + //Output: int64 +} diff --git a/example/chat/controllers/ws.go b/example/chat/controllers/ws.go index 862f89c9..9ec5b418 100644 --- a/example/chat/controllers/ws.go +++ b/example/chat/controllers/ws.go @@ -150,12 +150,12 @@ type WSController struct { } var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, + ReadBufferSize: 1024, + WriteBufferSize: 1024, } func (this *WSController) Get() { - ws, err := upgrader.Upgrade(this.Ctx.ResponseWriter, this.Ctx.Request,nil) + ws, err := upgrader.Upgrade(this.Ctx.ResponseWriter, this.Ctx.Request, nil) if _, ok := err.(websocket.HandshakeError); ok { http.Error(this.Ctx.ResponseWriter, "Not a websocket handshake", 400) return diff --git a/filter.go b/filter.go index 294966d4..ddd61094 100644 --- a/filter.go +++ b/filter.go @@ -14,12 +14,18 @@ package beego +import "github.com/astaxie/beego/context" + +// FilterFunc defines filter function type. +type FilterFunc func(*context.Context) + // FilterRouter defines filter operation before controller handler execution. // it can match patterned url and do filter function when action arrives. type FilterRouter struct { - filterFunc FilterFunc - tree *Tree - pattern string + filterFunc FilterFunc + tree *Tree + pattern string + returnOnOutput bool } // ValidRouter check current request is valid for this filter. diff --git a/flash.go b/flash.go index 6e85141f..5ccf339a 100644 --- a/flash.go +++ b/flash.go @@ -32,6 +32,24 @@ func NewFlash() *FlashData { } } +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data[key] = msg + } else { + fd.Data[key] = fmt.Sprintf(msg, args...) + } +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["success"] = msg + } else { + fd.Data["success"] = fmt.Sprintf(msg, args...) + } +} + // Notice writes notice message to flash. func (fd *FlashData) Notice(msg string, args ...interface{}) { if len(args) == 0 { diff --git a/httplib/httplib.go b/httplib/httplib.go index b1f209c9..37ba3b33 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -37,6 +37,7 @@ import ( "encoding/xml" "io" "io/ioutil" + "log" "mime/multipart" "net" "net/http" @@ -275,35 +276,36 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { } else { b.url = b.url + "?" + paramBody } - } else if b.req.Method == "POST" && b.req.Body == nil && len(paramBody) > 0 { + } else if b.req.Method == "POST" && b.req.Body == nil { if len(b.files) > 0 { - bodyBuf := &bytes.Buffer{} - bodyWriter := multipart.NewWriter(bodyBuf) - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - return nil, err + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + if err != nil { + log.Fatal(err) + } + fh, err := os.Open(filename) + if err != nil { + log.Fatal(err) + } + //iocopy + _, err = io.Copy(fileWriter, fh) + fh.Close() + if err != nil { + log.Fatal(err) + } } - fh, err := os.Open(filename) - if err != nil { - return nil, err + for k, v := range b.params { + bodyWriter.WriteField(k, v) } - //iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - return nil, err - } - } - for k, v := range b.params { - bodyWriter.WriteField(k, v) - } - contentType := bodyWriter.FormDataContentType() - bodyWriter.Close() - b.Header("Content-Type", contentType) - b.req.Body = ioutil.NopCloser(bodyBuf) - b.req.ContentLength = int64(bodyBuf.Len()) - } else { + bodyWriter.Close() + pw.Close() + }() + b.Header("Content-Type", bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + } else if len(paramBody) > 0 { b.Header("Content-Type", "application/x-www-form-urlencoded") b.Body(paramBody) } @@ -355,7 +357,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { Jar: jar, } - if b.setting.UserAgent != "" { + if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { b.req.Header.Set("User-Agent", b.setting.UserAgent) } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 02068c0b..0b551c53 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -66,23 +66,24 @@ func TestSimplePost(t *testing.T) { } } -func TestPostFile(t *testing.T) { - v := "smallfish" - req := Post("http://httpbin.org/post") - req.Param("username", v) - req.PostFile("uploadfile", "httplib_test.go") +//func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in post") - } -} +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +//} func TestSimplePut(t *testing.T) { str, err := Put("http://httpbin.org/put").String() @@ -203,3 +204,13 @@ func TestToFile(t *testing.T) { t.Fatal(err) } } + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/logs/log.go b/logs/log.go index 341df572..6abfb005 100644 --- a/logs/log.go +++ b/logs/log.go @@ -155,6 +155,9 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { lm.level = loglevel if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) + if _, filename := path.Split(file); filename == "log.go" && (line == 97 || line == 83) { + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth + 1) + } if ok { _, filename := path.Split(file) lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg) diff --git a/migration/ddl.go b/migration/ddl.go new file mode 100644 index 00000000..f9b60117 --- /dev/null +++ b/migration/ddl.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +type Table struct { + TableName string + Columns []*Column +} + +func (t *Table) Create() string { + return "" +} + +func (t *Table) Drop() string { + return "" +} + +type Column struct { + Name string + Type string + Default interface{} +} + +func Create(tbname string, columns ...Column) string { + return "" +} + +func Drop(tbname string, columns ...Column) string { + return "" +} + +func TableDDL(tbname string, columns ...Column) string { + return "" +} diff --git a/namespace.go b/namespace.go index d0109291..4e2632e5 100644 --- a/namespace.go +++ b/namespace.go @@ -217,7 +217,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { n.handlers.routers[k] = t } } - if n.handlers.enableFilter { + if ni.handlers.enableFilter { for pos, filterList := range ni.handlers.filters { for _, mr := range filterList { t := NewTree() diff --git a/orm/README.md b/orm/README.md index 74f1b47b..356407c1 100644 --- a/orm/README.md +++ b/orm/README.md @@ -6,8 +6,6 @@ A powerful orm framework for go. It is heavily influenced by Django ORM, SQLAlchemy. -now, beta, unstable, may be changing some api make your app build failed. - **Support Database:** * MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) diff --git a/orm/qb.go b/orm/qb.go new file mode 100644 index 00000000..efe368db --- /dev/null +++ b/orm/qb.go @@ -0,0 +1,57 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import "errors" + +type QueryBuilder interface { + Select(fields ...string) QueryBuilder + From(tables ...string) QueryBuilder + InnerJoin(table string) QueryBuilder + LeftJoin(table string) QueryBuilder + RightJoin(table string) QueryBuilder + On(cond string) QueryBuilder + Where(cond string) QueryBuilder + And(cond string) QueryBuilder + Or(cond string) QueryBuilder + In(vals ...string) QueryBuilder + OrderBy(fields ...string) QueryBuilder + Asc() QueryBuilder + Desc() QueryBuilder + Limit(limit int) QueryBuilder + Offset(offset int) QueryBuilder + GroupBy(fields ...string) QueryBuilder + Having(cond string) QueryBuilder + Update(tables ...string) QueryBuilder + Set(kv ...string) QueryBuilder + Delete(tables ...string) QueryBuilder + InsertInto(table string, fields ...string) QueryBuilder + Values(vals ...string) QueryBuilder + Subquery(sub string, alias string) string + String() string +} + +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + if driver == "mysql" { + qb = new(MySQLQueryBuilder) + } else if driver == "postgres" { + err = errors.New("postgres query builder is not supported yet!") + } else if driver == "sqlite" { + err = errors.New("sqlite query builder is not supported yet!") + } else { + err = errors.New("unknown driver for query builder!") + } + return +} diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go new file mode 100644 index 00000000..9ce9b7d9 --- /dev/null +++ b/orm/qb_mysql.go @@ -0,0 +1,153 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +const COMMA_SPACE = ", " + +type MySQLQueryBuilder struct { + Tokens []string +} + +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")") + return qb +} + +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE)) + return qb +} + +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE)) + } + return qb +} + +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, COMMA_SPACE) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, COMMA_SPACE) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +func (qb *MySQLQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/parser.go b/parser.go index d29470dc..0acdc8f7 100644 --- a/parser.go +++ b/parser.go @@ -42,6 +42,7 @@ func init() { var ( lastupdateFilename string = "lastupdate.tmp" + commentFilename string = "commentsRouter.go" pkgLastupdate map[string]int64 genInfoList map[string][]ControllerComments ) @@ -52,6 +53,8 @@ func init() { } func parserPkg(pkgRealpath, pkgpath string) error { + rep := strings.NewReplacer("/", "_", ".", "_") + commentFilename = rep.Replace(pkgpath) + "_" + commentFilename if !compareFile(pkgRealpath) { Info(pkgRealpath + " don't has updated") return nil @@ -155,7 +158,7 @@ func genRouterCode() { } } if globalinfo != "" { - f, err := os.Create(path.Join(workPath, "routers", "commentsRouter.go")) + f, err := os.Create(path.Join(workPath, "routers", commentFilename)) if err != nil { panic(err) } @@ -165,7 +168,7 @@ func genRouterCode() { } func compareFile(pkgRealpath string) bool { - if !utils.FileExists(path.Join(workPath, "routers", "commentsRouter.go")) { + if !utils.FileExists(path.Join(workPath, "routers", commentFilename)) { return true } if utils.FileExists(path.Join(workPath, lastupdateFilename)) { diff --git a/router.go b/router.go index 7c52b9fc..1f47c907 100644 --- a/router.go +++ b/router.go @@ -72,9 +72,31 @@ var ( "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", "GetControllerAndAction"} - url_placeholder = "{{placeholder}}" + url_placeholder = "{{placeholder}}" + DefaultLogFilter FilterHandler = &logFilter{} ) +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +// default log filter static file will not show +type logFilter struct { +} + +func (l *logFilter) Filter(ctx *beecontext.Context) bool { + requestPath := path.Clean(ctx.Input.Request.URL.Path) + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + return true + } + for prefix, _ := range StaticDir { + if strings.HasPrefix(requestPath, prefix) { + return true + } + } + return false +} + // To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter func ExceptMethodAppend(action string) { exceptMethod = append(exceptMethod, action) @@ -163,6 +185,9 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM } func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) { + if !RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } if t, ok := p.routers[method]; ok { t.AddRouter(pattern, r) } else { @@ -376,11 +401,21 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) } // Add a FilterFunc with pattern rule and action constant. -func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error { +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = pattern mr.filterFunc = filter + if !RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if len(params) == 0 { + mr.returnOnOutput = true + } else { + mr.returnOnOutput = params[0] + } mr.tree.AddRouter(pattern, true) return p.insertFilterRouter(pos, mr) } @@ -415,10 +450,10 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...string) string { } } } - controllName := strings.Join(paths[:len(paths)-1], ".") + controllName := strings.Join(paths[:len(paths)-1], "/") methodName := paths[len(paths)-1] - for _, t := range p.routers { - ok, url := p.geturl(t, "/", controllName, methodName, params) + for m, t := range p.routers { + ok, url := p.geturl(t, "/", controllName, methodName, params, m) if ok { return url } @@ -426,24 +461,25 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...string) string { return "" } -func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string) (bool, string) { +func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { for k, subtree := range t.fixrouters { u := path.Join(url, k) - ok, u := p.geturl(subtree, u, controllName, methodName, params) + ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod) if ok { return ok, u } } if t.wildcard != nil { - url = path.Join(url, url_placeholder) - ok, u := p.geturl(t.wildcard, url, controllName, methodName, params) + u := path.Join(url, url_placeholder) + ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod) if ok { return ok, u } } for _, l := range t.leaves { if c, ok := l.runObject.(*controllerInfo); ok { - if c.routerType == routerTypeBeego && c.controllerType.Name() == controllName { + if c.routerType == routerTypeBeego && + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { find := false if _, ok := HTTPMETHOD[strings.ToUpper(methodName)]; ok { if len(c.methods) == 0 { @@ -455,8 +491,8 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin } } if !find { - for _, md := range c.methods { - if md == methodName { + for m, md := range c.methods { + if (m == "*" || m == httpMethod) && md == methodName { find = true } } @@ -564,15 +600,21 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) context.Output.Context = context context.Output.EnableGzip = EnableGzip + var urlPath string + if !RouterCaseSensitive { + urlPath = strings.ToLower(r.URL.Path) + } else { + urlPath = r.URL.Path + } // defined filter function do_filter := func(pos int) (started bool) { if p.enableFilter { if l, ok := p.filters[pos]; ok { for _, filterR := range l { - if ok, p := filterR.ValidRouter(r.URL.Path); ok { + if ok, p := filterR.ValidRouter(urlPath); ok { context.Input.Params = p filterR.filterFunc(context) - if w.started { + if filterR.returnOnOutput && w.started { return true } } @@ -602,7 +644,13 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) // session init if SessionOn { - context.Input.CruSession = GlobalSessions.SessionStart(w, r) + var err error + context.Input.CruSession, err = GlobalSessions.SessionStart(w, r) + if err != nil { + Error(err) + middleware.Exception("503", rw, r, "") + return + } defer func() { context.Input.CruSession.SessionRelease(w) }() @@ -626,8 +674,18 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } if !findrouter { - if t, ok := p.routers[r.Method]; ok { - runObject, p := t.Match(r.URL.Path) + http_method := r.Method + + if http_method == "POST" && context.Input.Query("_method") == "PUT" { + http_method = "PUT" + } + + if http_method == "POST" && context.Input.Query("_method") == "DELETE" { + http_method = "DELETE" + } + + if t, ok := p.routers[http_method]; ok { + runObject, p := t.Match(urlPath) if r, ok := runObject.(*controllerInfo); ok { routerInfo = r findrouter = true @@ -783,7 +841,9 @@ Admin: } else { devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "notmatch") } - Debug(devinfo) + if DefaultLogFilter == nil || !DefaultLogFilter.Filter(context) { + Debug(devinfo) + } } // Call WriteHeader if status code has been set changed @@ -797,7 +857,9 @@ func (p *ControllerRegistor) recoverPanic(rw http.ResponseWriter, r *http.Reques if err == USERSTOPRUN { return } - if _, ok := err.(middleware.HTTPException); ok { + if he, ok := err.(middleware.HTTPException); ok { + rw.WriteHeader(he.StatusCode) + rw.Write([]byte(he.Description)) // catch intented errors, only for HTTP 4XX and 5XX } else { if RunMode == "dev" { @@ -829,9 +891,15 @@ func (p *ControllerRegistor) recoverPanic(rw http.ResponseWriter, r *http.Reques } else { // in production model show all infomation if ErrorsShow { - handler := p.getErrorHandler(fmt.Sprint(err)) - handler(rw, r) - return + if handler, ok := middleware.ErrorMaps[fmt.Sprint(err)]; ok { + handler(rw, r) + return + } else if handler, ok := middleware.ErrorMaps["503"]; ok { + handler(rw, r) + return + } else { + rw.Write([]byte(fmt.Sprint(err))) + } } else { Critical("the request url is ", r.URL.Path) Critical("Handler crashed with error", err) @@ -850,24 +918,6 @@ func (p *ControllerRegistor) recoverPanic(rw http.ResponseWriter, r *http.Reques } } -// there always should be error handler that sets error code accordingly for all unhandled errors. -// in order to have custom UI for error page it's necessary to override "500" error. -func (p *ControllerRegistor) getErrorHandler(errorCode string) func(rw http.ResponseWriter, r *http.Request) { - handler := middleware.SimpleServerError - ok := true - if errorCode != "" { - handler, ok = middleware.ErrorMaps[errorCode] - if !ok { - handler, ok = middleware.ErrorMaps["500"] - } - if !ok || handler == nil { - handler = middleware.SimpleServerError - } - } - - return handler -} - //responseWriter is a wrapper for the http.ResponseWriter //started set to true if response was written to then don't execute other handler type responseWriter struct { diff --git a/router_test.go b/router_test.go index d378589b..c64bb191 100644 --- a/router_test.go +++ b/router_test.go @@ -17,6 +17,7 @@ package beego import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/astaxie/beego/context" @@ -385,3 +386,196 @@ func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request return recorder, request } + +// Execution point: BeforeRouter +// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle +func TestFilterBeforeRouter(t *testing.T) { + testName := "TestFilterBeforeRouter" + url := "/beforeRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "BeforeRouter1") == false { + t.Errorf(testName + " BeforeRouter did not run") + } + if strings.Contains(rw.Body.String(), "hello") == true { + t.Errorf(testName + " BeforeRouter did not return properly") + } +} + +// Execution point: BeforeExec +// expectation: only BeforeExec function is executed, match as router determines route only +func TestFilterBeforeExec(t *testing.T) { + testName := "TestFilterBeforeExec" + url := "/beforeExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "BeforeExec1") == false { + t.Errorf(testName + " BeforeExec did not run") + } + if strings.Contains(rw.Body.String(), "hello") == true { + t.Errorf(testName + " BeforeExec did not return properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + t.Errorf(testName + " BeforeRouter ran in error") + } +} + +// Execution point: AfterExec +// expectation: only AfterExec function is executed, match as router handles +func TestFilterAfterExec(t *testing.T) { + testName := "TestFilterAfterExec" + url := "/afterExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoAfterExec1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "AfterExec1") == false { + t.Errorf(testName + " AfterExec did not run") + } + if strings.Contains(rw.Body.String(), "hello") == false { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") == true { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only FinishRouter function is executed, match as router handles +func TestFilterFinishRouter(t *testing.T) { + testName := "TestFilterFinishRouter" + url := "/finishRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") == true { + t.Errorf(testName + " FinishRouter did not run") + } + if strings.Contains(rw.Body.String(), "hello") == false { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "AfterExec1") == true { + t.Errorf(testName + " AfterExec ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") == true { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") == true { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only first FinishRouter function is executed, match as router handles +func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { + testName := "TestFilterFinishRouterMultiFirstOnly" + url := "/finishRouterMultiFirstOnly" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + t.Errorf(testName + " FinishRouter1 did not run") + } + if strings.Contains(rw.Body.String(), "hello") == false { + t.Errorf(testName + " handler did not run properly") + } + // not expected in body + if strings.Contains(rw.Body.String(), "FinishRouter2") == true { + t.Errorf(testName + " FinishRouter2 did run") + } +} + +// Execution point: FinishRouter +// expectation: both FinishRouter functions execute, match as router handles +func TestFilterFinishRouterMulti(t *testing.T) { + testName := "TestFilterFinishRouterMulti" + url := "/finishRouterMulti" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") == false { + t.Errorf(testName + " FinishRouter1 did not run") + } + if strings.Contains(rw.Body.String(), "hello") == false { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "FinishRouter2") == false { + t.Errorf(testName + " FinishRouter2 did not run properly") + } +} + +func beegoFilterNoOutput(ctx *context.Context) { + return +} +func beegoBeforeRouter1(ctx *context.Context) { + ctx.WriteString("|BeforeRouter1") +} +func beegoBeforeRouter2(ctx *context.Context) { + ctx.WriteString("|BeforeRouter2") +} +func beegoBeforeExec1(ctx *context.Context) { + ctx.WriteString("|BeforeExec1") +} +func beegoBeforeExec2(ctx *context.Context) { + ctx.WriteString("|BeforeExec2") +} +func beegoAfterExec1(ctx *context.Context) { + ctx.WriteString("|AfterExec1") +} +func beegoAfterExec2(ctx *context.Context) { + ctx.WriteString("|AfterExec2") +} +func beegoFinishRouter1(ctx *context.Context) { + ctx.WriteString("|FinishRouter1") +} +func beegoFinishRouter2(ctx *context.Context) { + ctx.WriteString("|FinishRouter2") +} diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go new file mode 100644 index 00000000..3ada47ac --- /dev/null +++ b/session/ledis/ledis_session.go @@ -0,0 +1,168 @@ +package session + +import ( + "net/http" + "sync" + + "github.com/astaxie/beego/session" + "github.com/siddontang/ledisdb/config" + "github.com/siddontang/ledisdb/ledis" +) + +var ledispder = &LedisProvider{} +var c *ledis.DB + +// ledis session store +type LedisSessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// set value in ledis session +func (ls *LedisSessionStore) Set(key, value interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values[key] = value + return nil +} + +// get value in ledis session +func (ls *LedisSessionStore) Get(key interface{}) interface{} { + ls.lock.RLock() + defer ls.lock.RUnlock() + if v, ok := ls.values[key]; ok { + return v + } else { + return nil + } +} + +// delete value in ledis session +func (ls *LedisSessionStore) Delete(key interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + delete(ls.values, key) + return nil +} + +// clear all values in ledis session +func (ls *LedisSessionStore) Flush() error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values = make(map[interface{}]interface{}) + return nil +} + +// get ledis session id +func (ls *LedisSessionStore) SessionID() string { + return ls.sid +} + +// save session values to ledis +func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(ls.values) + if err != nil { + return + } + c.Set([]byte(ls.sid), b) + c.Expire([]byte(ls.sid), ls.maxlifetime) +} + +// ledis session provider +type LedisProvider struct { + maxlifetime int64 + savePath string +} + +// init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error { + lp.maxlifetime = maxlifetime + lp.savePath = savePath + cfg := new(config.Config) + cfg.DataDir = lp.savePath + var err error + nowLedis, err := ledis.Open(cfg) + c, err = nowLedis.Select(0) + if err != nil { + println(err) + return nil + } + return nil +} + +// read ledis session by sid +func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) { + kvs, err := c.Get([]byte(sid)) + var kv map[interface{}]interface{} + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(kvs) + if err != nil { + return nil, err + } + } + ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// check ledis session exist by sid +func (lp *LedisProvider) SessionExist(sid string) bool { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + return false + } else { + return true + } +} + +// generate new sid for ledis session +func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set([]byte(sid), []byte("")) + c.Expire([]byte(sid), lp.maxlifetime) + } else { + data, _ := c.Get([]byte(oldsid)) + c.Set([]byte(sid), data) + c.Expire([]byte(sid), lp.maxlifetime) + } + kvs, err := c.Get([]byte(sid)) + var kv map[interface{}]interface{} + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(kvs)) + if err != nil { + return nil, err + } + } + ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// delete ledis session by id +func (lp *LedisProvider) SessionDestroy(sid string) error { + c.Del([]byte(sid)) + return nil +} + +// Impelment method, no used. +func (lp *LedisProvider) SessionGC() { + return +} + +// @todo +func (lp *LedisProvider) SessionAll() int { + return 0 +} +func init() { + session.Register("ledis", ledispder) +} diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index f2b0b29b..82cdd812 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -109,7 +109,7 @@ func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) { return } - c.Do("SET", rs.sid, string(b), "EX", rs.maxlifetime) + c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) } // redis session provider diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go index 4f40a7ba..fe3ac806 100644 --- a/session/sess_cookie_test.go +++ b/session/sess_cookie_test.go @@ -29,7 +29,10 @@ func TestCookie(t *testing.T) { } r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - sess := globalSessions.SessionStart(w, r) + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } err = sess.Set("username", "astaxie") if err != nil { t.Fatal("set error,", err) diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go index 03927c76..43f5b0a9 100644 --- a/session/sess_mem_test.go +++ b/session/sess_mem_test.go @@ -26,9 +26,12 @@ func TestMem(t *testing.T) { go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - sess := globalSessions.SessionStart(w, r) + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } defer sess.SessionRelease(w) - err := sess.Set("username", "astaxie") + err = sess.Set("username", "astaxie") if err != nil { t.Fatal("set error,", err) } diff --git a/session/session.go b/session/session.go index 88e94d59..3cbd2b05 100644 --- a/session/session.go +++ b/session/session.go @@ -28,19 +28,13 @@ package session import ( - "crypto/hmac" - "crypto/md5" "crypto/rand" - "crypto/sha1" "encoding/hex" "encoding/json" "fmt" - "io" "net/http" "net/url" "time" - - "github.com/astaxie/beego/utils" ) // SessionStore contains all data for one session process with specific id. @@ -81,16 +75,15 @@ func Register(name string, provide Provider) { } type managerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - Secure bool `json:"secure"` - SessionIDHashFunc string `json:"sessionIDHashFunc"` - SessionIDHashKey string `json:"sessionIDHashKey"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIdLength int64 `json:"sessionIdLength"` } // Manager contains Provider and its configuration. @@ -129,11 +122,9 @@ func NewManager(provideName, config string) (*Manager, error) { if err != nil { return nil, err } - if cf.SessionIDHashFunc == "" { - cf.SessionIDHashFunc = "sha1" - } - if cf.SessionIDHashKey == "" { - cf.SessionIDHashKey = string(generateRandomKey(16)) + + if cf.SessionIdLength == 0 { + cf.SessionIdLength = 16 } return &Manager{ @@ -144,11 +135,14 @@ func NewManager(provideName, config string) (*Manager, error) { // Start session. generate or read the session id from http request. // if session id exists, return SessionStore with this id. -func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { - cookie, err := r.Cookie(manager.config.CookieName) - if err != nil || cookie.Value == "" { - sid := manager.sessionId(r) - session, _ = manager.provider.SessionRead(sid) +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore, err error) { + cookie, errs := r.Cookie(manager.config.CookieName) + if errs != nil || cookie.Value == "" { + sid, errs := manager.sessionId(r) + if errs != nil { + return nil, errs + } + session, err = manager.provider.SessionRead(sid) cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", @@ -163,12 +157,18 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } r.AddCookie(cookie) } else { - sid, _ := url.QueryUnescape(cookie.Value) + sid, errs := url.QueryUnescape(cookie.Value) + if errs != nil { + return nil, errs + } if manager.provider.SessionExist(sid) { - session, _ = manager.provider.SessionRead(sid) + session, err = manager.provider.SessionRead(sid) } else { - sid = manager.sessionId(r) - session, _ = manager.provider.SessionRead(sid) + sid, err = manager.sessionId(r) + if err != nil { + return nil, err + } + session, err = manager.provider.SessionRead(sid) cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", @@ -219,7 +219,10 @@ func (manager *Manager) GC() { // Regenerate a session id for this SessionStore who's id is saving in http request. func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { - sid := manager.sessionId(r) + sid, err := manager.sessionId(r) + if err != nil { + return + } cookie, err := r.Cookie(manager.config.CookieName) if err != nil && cookie.Value == "" { //delete old cookie @@ -251,36 +254,16 @@ func (manager *Manager) GetActiveSession() int { return manager.provider.SessionAll() } -// Set hash function for generating session id. -func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { - manager.config.SessionIDHashFunc = hasfunc - manager.config.SessionIDHashKey = hashkey -} - // Set cookie with https. func (manager *Manager) SetSecure(secure bool) { manager.config.Secure = secure } -// generate session id with rand string, unix nano time, remote addr by hash function. -func (manager *Manager) sessionId(r *http.Request) (sid string) { - bs := make([]byte, 32) - if n, err := io.ReadFull(rand.Reader, bs); n != 32 || err != nil { - bs = utils.RandomCreateBytes(32) +func (manager *Manager) sessionId(r *http.Request) (string, error) { + b := make([]byte, manager.config.SessionIdLength) + n, err := rand.Read(b) + if n != len(b) || err != nil { + return "", fmt.Errorf("Could not successfully read from the system CSPRNG.") } - sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) - if manager.config.SessionIDHashFunc == "md5" { - h := md5.New() - h.Write([]byte(sig)) - sid = hex.EncodeToString(h.Sum(nil)) - } else if manager.config.SessionIDHashFunc == "sha1" { - h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) - fmt.Fprintf(h, "%s", sig) - sid = hex.EncodeToString(h.Sum(nil)) - } else { - h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) - fmt.Fprintf(h, "%s", sig) - sid = hex.EncodeToString(h.Sum(nil)) - } - return + return hex.EncodeToString(b), nil } diff --git a/staticfile.go b/staticfile.go index a9deabe9..d9855064 100644 --- a/staticfile.go +++ b/staticfile.go @@ -31,6 +31,7 @@ func serverStaticRouter(ctx *context.Context) { return } requestPath := path.Clean(ctx.Input.Request.URL.Path) + i := 0 for prefix, staticDir := range StaticDir { if len(prefix) == 0 { continue @@ -41,8 +42,13 @@ func serverStaticRouter(ctx *context.Context) { http.ServeFile(ctx.ResponseWriter, ctx.Request, file) return } else { - http.NotFound(ctx.ResponseWriter, ctx.Request) - return + i++ + if i == len(StaticDir) { + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } else { + continue + } } } if strings.HasPrefix(requestPath, prefix) { @@ -59,9 +65,20 @@ func serverStaticRouter(ctx *context.Context) { return } //if the request is dir and DirectoryIndex is false then - if finfo.IsDir() && !DirectoryIndex { - middleware.Exception("403", ctx.ResponseWriter, ctx.Request, "403 Forbidden") - return + if finfo.IsDir() { + if !DirectoryIndex { + middleware.Exception("403", ctx.ResponseWriter, ctx.Request, "403 Forbidden") + return + } else if ctx.Input.Request.URL.Path[len(ctx.Input.Request.URL.Path)-1] != '/' { + http.Redirect(ctx.ResponseWriter, ctx.Request, ctx.Input.Request.URL.Path+"/", 302) + return + } + } else if strings.HasSuffix(requestPath, "/index.html") { + file := path.Join(staticDir, requestPath) + if utils.FileExists(file) { + http.ServeFile(ctx.ResponseWriter, ctx.Request, file) + return + } } //This block obtained from (https://github.com/smithfox/beego) - it should probably get merged into astaxie/beego after a pull request diff --git a/templatefunc.go b/templatefunc.go index a365718d..16067613 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -302,6 +302,14 @@ func ParseForm(form url.Values, obj interface{}) error { switch fieldT.Type.Kind() { case reflect.Bool: + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + fieldV.SetBool(true) + continue + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + fieldV.SetBool(false) + continue + } b, err := strconv.ParseBool(value) if err != nil { return err @@ -329,6 +337,19 @@ func ParseForm(form url.Values, obj interface{}) error { fieldV.Set(reflect.ValueOf(value)) case reflect.String: fieldV.SetString(value) + case reflect.Struct: + switch fieldT.Type.String() { + case "time.Time": + format := time.RFC3339 + if len(tags) > 1 { + format = tags[1] + } + t, err := time.Parse(format, value) + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } } } return nil @@ -368,23 +389,31 @@ func RenderForm(obj interface{}) template.HTML { fieldT := objT.Field(i) - label, name, fType, ignored := parseFormTag(fieldT) + label, name, fType, id, class, ignored := parseFormTag(fieldT) if ignored { continue } - raw = append(raw, renderFormField(label, name, fType, fieldV.Interface())) + raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class)) } return template.HTML(strings.Join(raw, "
")) } // renderFormField returns a string containing HTML of a single form field. -func renderFormField(label, name, fType string, value interface{}) string { - if isValidForInput(fType) { - return fmt.Sprintf(`%v`, label, name, fType, value) +func renderFormField(label, name, fType string, value interface{}, id string, class string) string { + if id != "" { + id = " id=\"" + id + "\"" } - return fmt.Sprintf(`%v<%v name="%v">%v`, label, fType, name, value, fType) + if class != "" { + class = " class=\"" + class + "\"" + } + + if isValidForInput(fType) { + return fmt.Sprintf(`%v`, label, id, class, name, fType, value) + } + + return fmt.Sprintf(`%v<%v%v%v name="%v">%v`, label, fType, id, class, name, value, fType) } // isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. @@ -400,12 +429,14 @@ func isValidForInput(fType string) bool { // parseFormTag takes the stuct-tag of a StructField and parses the `form` value. // returned are the form label, name-property, type and wether the field should be ignored. -func parseFormTag(fieldT reflect.StructField) (label, name, fType string, ignored bool) { +func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool) { tags := strings.Split(fieldT.Tag.Get("form"), ",") label = fieldT.Name + ": " name = fieldT.Name fType = "text" ignored = false + id = fieldT.Tag.Get("id") + class = fieldT.Tag.Get("class") switch len(tags) { case 1: diff --git a/templatefunc_test.go b/templatefunc_test.go index 44a06dec..3692a821 100644 --- a/templatefunc_test.go +++ b/templatefunc_test.go @@ -102,12 +102,14 @@ func TestHtmlunquote(t *testing.T) { func TestParseForm(t *testing.T) { type user struct { - Id int `form:"-"` - tag string `form:"tag"` - Name interface{} `form:"username"` - Age int `form:"age,text"` - Email string - Intro string `form:",textarea"` + Id int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` } u := user{} @@ -119,6 +121,8 @@ func TestParseForm(t *testing.T) { "age": []string{"40"}, "Email": []string{"test@gmail.com"}, "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, } if err := ParseForm(form, u); err == nil { t.Fatal("nothing will be changed") @@ -144,6 +148,13 @@ func TestParseForm(t *testing.T) { if u.Intro != "I am an engineer!" { t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) } + if u.StrBool != true { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } } func TestRenderForm(t *testing.T) { @@ -175,12 +186,12 @@ func TestRenderForm(t *testing.T) { } func TestRenderFormField(t *testing.T) { - html := renderFormField("Label: ", "Name", "text", "Value") + html := renderFormField("Label: ", "Name", "text", "Value", "", "") if html != `Label: ` { t.Errorf("Wrong html output for input[type=text]: %v ", html) } - html = renderFormField("Label: ", "Name", "textarea", "Value") + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "") if html != `Label: ` { t.Errorf("Wrong html output for textarea: %v ", html) } @@ -192,33 +203,34 @@ func TestParseFormTag(t *testing.T) { All int `form:"name,text,年龄:"` NoName int `form:",hidden,年龄:"` OnlyLabel int `form:",,年龄:"` - OnlyName int `form:"name"` + OnlyName int `form:"name" id:"name" class:"form-name"` Ignored int `form:"-"` } objT := reflect.TypeOf(&user{}).Elem() - label, name, fType, ignored := parseFormTag(objT.Field(0)) + label, name, fType, id, class, ignored := parseFormTag(objT.Field(0)) if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) { t.Errorf("Form Tag with name, label and type was not correctly parsed.") } - label, name, fType, ignored = parseFormTag(objT.Field(1)) + label, name, fType, id, class, ignored = parseFormTag(objT.Field(1)) if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) { t.Errorf("Form Tag with label and type but without name was not correctly parsed.") } - label, name, fType, ignored = parseFormTag(objT.Field(2)) + label, name, fType, id, class, ignored = parseFormTag(objT.Field(2)) if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) { t.Errorf("Form Tag containing only label was not correctly parsed.") } - label, name, fType, ignored = parseFormTag(objT.Field(3)) - if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false) { + label, name, fType, id, class, ignored = parseFormTag(objT.Field(3)) + if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false && + id == "name" && class == "form-name") { t.Errorf("Form Tag containing only name was not correctly parsed.") } - label, name, fType, ignored = parseFormTag(objT.Field(4)) + label, name, fType, id, class, ignored = parseFormTag(objT.Field(4)) if ignored == false { t.Errorf("Form Tag that should be ignored was not correctly parsed.") } diff --git a/toolbox/statistics.go b/toolbox/statistics.go index 382daba0..beeafc7b 100644 --- a/toolbox/statistics.go +++ b/toolbox/statistics.go @@ -111,6 +111,27 @@ func (m *UrlMap) GetMap() map[string]interface{} { return content } +func (m *UrlMap) GetMapData() []map[string]interface{} { + + resultLists := make([]map[string]interface{}, 0) + + for k, v := range m.urlmap { + for kk, vv := range v { + result := map[string]interface{}{ + "request_url": k, + "method": kk, + "times": vv.RequestNum, + "total_time": toS(vv.TotalTime), + "max_time": toS(vv.MaxTime), + "min_time": toS(vv.MinTime), + "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + } + resultLists = append(resultLists, result) + } + } + return resultLists +} + // global statistics data map var StatisticsMap *UrlMap diff --git a/toolbox/statistics_test.go b/toolbox/statistics_test.go index 448b2af5..ac29476c 100644 --- a/toolbox/statistics_test.go +++ b/toolbox/statistics_test.go @@ -15,6 +15,7 @@ package toolbox import ( + "encoding/json" "testing" "time" ) @@ -28,4 +29,12 @@ func TestStatics(t *testing.T) { StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) } diff --git a/tree.go b/tree.go index 9f86dd48..25947442 100644 --- a/tree.go +++ b/tree.go @@ -394,6 +394,9 @@ func (leaf *leafInfo) match(wildcardValues []string) (ok bool, params map[string } return true, params } + if len(wildcardValues) <= j { + return false, nil + } params[v] = wildcardValues[j] j += 1 } diff --git a/utils/mail.go b/utils/mail.go index c7ab73d8..aa219626 100644 --- a/utils/mail.go +++ b/utils/mail.go @@ -157,19 +157,37 @@ func (e *Email) Bytes() ([]byte, error) { } // Add attach file to the send mail -func (e *Email) AttachFile(filename string) (a *Attachment, err error) { +func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { + if len(args) < 1 && len(args) > 2 { + err = errors.New("Must specify a file name and number of parameters can not exceed at least two") + return + } + filename := args[0] + id := "" + if len(args) > 1 { + id = args[1] + } f, err := os.Open(filename) if err != nil { return } ct := mime.TypeByExtension(filepath.Ext(filename)) basename := path.Base(filename) - return e.Attach(f, basename, ct) + return e.Attach(f, basename, ct, id) } // Attach is used to attach content from an io.Reader to the email. // Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. -func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) { +func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { + if len(args) < 1 && len(args) > 2 { + err = errors.New("Must specify the file type and number of parameters can not exceed at least two") + return + } + c := args[0] //Content-Type + id := "" + if len(args) > 1 { + id = args[1] //Content-ID + } var buffer bytes.Buffer if _, err = io.Copy(&buffer, r); err != nil { return @@ -186,7 +204,12 @@ func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, e // If the Content-Type is blank, set the Content-Type to "application/octet-stream" at.Header.Set("Content-Type", "application/octet-stream") } - at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + if id != "" { + at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) + at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) + } else { + at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + } at.Header.Set("Content-Transfer-Encoding", "base64") e.Attachments = append(e.Attachments, at) return at, nil @@ -269,7 +292,7 @@ func qpEscape(dest []byte, c byte) { const nums = "0123456789ABCDEF" dest[0] = '=' dest[1] = nums[(c&0xf0)>>4] - dest[2] = nums[(c & 0xf)] + dest[2] = nums[(c&0xf)] } // headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go new file mode 100644 index 00000000..28473f8a --- /dev/null +++ b/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/context" +) + +// Instantiates a Paginator and assigns it to context.Input.Data["paginator"]. +func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { + paginator = NewPaginator(context.Request, per, nums) + context.Input.Data["paginator"] = paginator + return +} diff --git a/utils/pagination/doc.go b/utils/pagination/doc.go new file mode 100644 index 00000000..df0fa3b7 --- /dev/null +++ b/utils/pagination/doc.go @@ -0,0 +1,59 @@ +/* + +The pagination package provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/utils/pagination/paginator.go b/utils/pagination/paginator.go new file mode 100644 index 00000000..f89e878e --- /dev/null +++ b/utils/pagination/paginator.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "math" + "net/http" + "net/url" + "strconv" +) + +// Paginator within the state of a http request. +type Paginator struct { + Request *http.Request + PerPageNums int + MaxPages int + + nums int64 + pageRange []int + pageNums int + page int +} + +// Returns the total number of pages. +func (p *Paginator) PageNums() int { + if p.pageNums != 0 { + return p.pageNums + } + pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) + if p.MaxPages > 0 { + pageNums = math.Min(pageNums, float64(p.MaxPages)) + } + p.pageNums = int(pageNums) + return p.pageNums +} + +// Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return p.nums +} + +// Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + p.nums, _ = ToInt64(nums) +} + +// Returns the current page. +func (p *Paginator) Page() int { + if p.page != 0 { + return p.page + } + if p.Request.Form == nil { + p.Request.ParseForm() + } + p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) + if p.page > p.PageNums() { + p.page = p.PageNums() + } + if p.page <= 0 { + p.page = 1 + } + return p.page +} + +// Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + if p.pageRange == nil && p.nums > 0 { + var pages []int + pageNums := p.PageNums() + page := p.Page() + switch { + case page >= pageNums-4 && pageNums > 9: + start := pageNums - 9 + 1 + pages = make([]int, 9) + for i, _ := range pages { + pages[i] = start + i + } + case page >= 5 && pageNums > 9: + start := page - 5 + 1 + pages = make([]int, int(math.Min(9, float64(page+4+1)))) + for i, _ := range pages { + pages[i] = start + i + } + default: + pages = make([]int, int(math.Min(9, float64(pageNums)))) + for i, _ := range pages { + pages[i] = i + 1 + } + } + p.pageRange = pages + } + return p.pageRange +} + +// Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + link, _ := url.ParseRequestURI(p.Request.RequestURI) + values := link.Query() + if page == 1 { + values.Del("p") + } else { + values.Set("p", strconv.Itoa(page)) + } + link.RawQuery = values.Encode() + return link.String() +} + +// Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + if p.HasPrev() { + link = p.PageLink(p.Page() - 1) + } + return +} + +// Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + if p.HasNext() { + link = p.PageLink(p.Page() + 1) + } + return +} + +// Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return p.PageLink(1) +} + +// Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return p.PageLink(p.PageNums()) +} + +// Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return p.Page() > 1 +} + +// Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return p.Page() < p.PageNums() +} + +// Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return p.Page() == page +} + +// Returns the current offset. +func (p *Paginator) Offset() int { + return (p.Page() - 1) * p.PerPageNums +} + +// Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return p.PageNums() > 1 +} + +// Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + p := Paginator{} + p.Request = req + if per <= 0 { + per = 10 + } + p.PerPageNums = per + p.SetNums(nums) + return &p +} diff --git a/utils/pagination/utils.go b/utils/pagination/utils.go new file mode 100644 index 00000000..5932647d --- /dev/null +++ b/utils/pagination/utils.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "fmt" + "reflect" +) + +// convert any numeric value to int64 +func ToInt64(value interface{}) (d int64, err error) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + err = fmt.Errorf("ToInt64 need numeric not `%T`", value) + } + return +}