diff --git a/.gitignore b/.gitignore index 39ae5706..9806457b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ .DS_Store *.swp *.swo +beego.iml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..c59cef61 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,30 @@ +language: go + +go: + - 1.5.1 + +services: + - redis-server + - mysql + - postgresql + - memcached +env: + - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db + - ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8" + - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" +install: + - go get github.com/lib/pq + - go get github.com/go-sql-driver/mysql + - go get github.com/mattn/go-sqlite3 + - go get github.com/bradfitz/gomemcache/memcache + - go get github.com/garyburd/redigo/redis + - go get github.com/beego/x2j + - go get github.com/beego/goyaml2 + - go get github.com/belogik/goes + - go get github.com/couchbase/go-couchbase + - go get github.com/siddontang/ledisdb/config + - go get github.com/siddontang/ledisdb/ledis +before_script: + - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" + - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" + - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..9d511616 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,52 @@ +# Contributing to beego + +beego is an open source project. + +It is the work of hundreds of contributors. We appreciate your help! + +Here are instructions to get you started. They are probably not perfect, +please let us know if anything feels wrong or incomplete. + +## Contribution guidelines + +### Pull requests + +First of all. beego follow the gitflow. So please send you pull request +to **develop** branch. We will close the pull request to master branch. + +We are always happy to receive pull requests, and do our best to +review them as fast as possible. Not sure if that typo is worth a pull +request? Do it! We will appreciate it. + +If your pull request is not accepted on the first try, don't be +discouraged! Sometimes we can make a mistake, please do more explaining +for us. We will appreciate it. + +We're trying very hard to keep beego simple and fast. We don't want it +to do everything for everybody. This means that we might decide against +incorporating a new feature. But we will give you some advice on how to +do it in other way. + +### Create issues + +Any significant improvement should be documented as [a GitHub +issue](https://github.com/astaxie/beego/issues) before anybody +starts working on it. + +Also when filing an issue, make sure to answer these five questions: + +- What version of beego are you using (bee version)? +- What operating system and processor architecture are you using? +- What did you do? +- What did you expect to see? +- What did you see instead? + +### but check existing issues and docs first! + +Please take a moment to check that an issue doesn't already exist +documenting your bug report or improvement proposal. If it does, it +never hurts to add a quick "+1" or "I have this problem too". This will +help prioritize the most common problems and requests. + +Also if you don't know how to use it. please make sure you have read though +the docs in http://beego.me/docs \ No newline at end of file diff --git a/README.md b/README.md index 7b650887..fec6113f 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,38 @@ ## Beego -[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) +[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) -beego is an open-source, high-performance, modular, full-stack web framework. +beego is used for rapid development of RESTful APIs, web apps and backend services in Go. +It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. More info [beego.me](http://beego.me) -## Installation +##Quick Start +######Download and install go get github.com/astaxie/beego +######Create file `hello.go` +```go +package main + +import "github.com/astaxie/beego" + +func main(){ + beego.Run() +} +``` +######Build and run +```bash + go build hello.go + ./hello +``` +######Congratulations! +You just built your first beego app. +Open your browser and visit `http://localhost:8000`. +Please see [Documentation](http://beego.me/docs) for more. + ## Features * RESTful support @@ -26,6 +48,7 @@ More info [beego.me](http://beego.me) * [English](http://beego.me/docs/intro/) * [中文文档](http://beego.me/docs/intro/) +* [Русский](http://beego.me/docs/intro/) ## Community @@ -33,5 +56,5 @@ More info [beego.me](http://beego.me) ## LICENSE -beego is licensed under the Apache Licence, Version 2.0 +beego source code is licensed under the Apache Licence, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0.html). diff --git a/admin.go b/admin.go index 64d7fe34..3effc582 100644 --- a/admin.go +++ b/admin.go @@ -65,24 +65,15 @@ func init() { // AdminIndex is the default http.Handler for admin module. // it matches url pattern "/". func adminIndex(rw http.ResponseWriter, r *http.Request) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(indexTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - data := make(map[interface{}]interface{}) - tmpl.Execute(rw, data) + execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) } // QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter. // it's registered with url pattern "/qbs" in admin module. func qpsIndex(rw http.ResponseWriter, r *http.Request) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(qpsTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) data := make(map[interface{}]interface{}) data["Content"] = toolbox.StatisticsMap.GetMap() - - tmpl.Execute(rw, data) - + execTpl(rw, data, qpsTpl, defaultScriptsTpl) } // ListConf is the http.Handler of displaying all beego configuration values as key/value pair. @@ -90,178 +81,145 @@ func qpsIndex(rw http.ResponseWriter, r *http.Request) { func listConf(rw http.ResponseWriter, r *http.Request) { r.ParseForm() command := r.Form.Get("command") - if command != "" { - data := make(map[interface{}]interface{}) - switch command { - case "conf": - m := make(map[string]interface{}) + if command == "" { + rw.Write([]byte("command not support")) + return + } - m["AppName"] = AppName - m["AppPath"] = AppPath - m["AppConfigPath"] = AppConfigPath - m["StaticDir"] = StaticDir - m["StaticExtensionsToGzip"] = StaticExtensionsToGzip - m["HttpAddr"] = HttpAddr - m["HttpPort"] = HttpPort - m["HttpTLS"] = EnableHttpTLS - m["HttpCertFile"] = HttpCertFile - m["HttpKeyFile"] = HttpKeyFile - m["RecoverPanic"] = RecoverPanic - m["AutoRender"] = AutoRender - m["ViewsPath"] = ViewsPath - m["RunMode"] = RunMode - m["SessionOn"] = SessionOn - m["SessionProvider"] = SessionProvider - m["SessionName"] = SessionName - m["SessionGCMaxLifetime"] = SessionGCMaxLifetime - m["SessionSavePath"] = SessionSavePath - m["SessionCookieLifeTime"] = SessionCookieLifeTime - m["UseFcgi"] = UseFcgi - m["MaxMemory"] = MaxMemory - m["EnableGzip"] = EnableGzip - m["DirectoryIndex"] = DirectoryIndex - m["HttpServerTimeOut"] = HttpServerTimeOut - m["ErrorsShow"] = ErrorsShow - m["XSRFKEY"] = XSRFKEY - m["EnableXSRF"] = EnableXSRF - m["XSRFExpire"] = XSRFExpire - m["CopyRequestBody"] = CopyRequestBody - m["TemplateLeft"] = TemplateLeft - m["TemplateRight"] = TemplateRight - m["BeegoServerName"] = BeegoServerName - m["EnableAdmin"] = EnableAdmin - m["AdminHttpAddr"] = AdminHttpAddr - m["AdminHttpPort"] = AdminHttpPort + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(map[string]interface{}) + m["AppConfigPath"] = AppConfigPath + m["AppConfigProvider"] = AppConfigProvider + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(configTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + data["Content"] = m - data["Content"] = m + tmpl.Execute(rw, data) - tmpl.Execute(rw, data) - - case "router": - content := make(map[string]interface{}) - - var fields = []string{ - fmt.Sprintf("Router Pattern"), - fmt.Sprintf("Methods"), - fmt.Sprintf("Controller"), + case "router": + var ( + content = map[string]interface{}{ + "Fields": []string{ + "Router Pattern", + "Methods", + "Controller", + }, } - content["Fields"] = fields + methods = []string{} + methodsData = make(map[string]interface{}) + ) + for method, t := range BeeApp.Handlers.routers { - methods := []string{} - methodsData := make(map[string]interface{}) - for method, t := range BeeApp.Handlers.routers { + resultList := new([][]string) - resultList := new([][]string) + printTree(resultList, t) - printTree(resultList, t) - - methods = append(methods, method) - methodsData[method] = resultList - } - - content["Data"] = methodsData - content["Methods"] = methods - data["Content"] = content - data["Title"] = "Routers" - - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(routerAndFilterTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - tmpl.Execute(rw, data) - case "filter": - content := make(map[string]interface{}) - - var fields = []string{ - fmt.Sprintf("Router Pattern"), - fmt.Sprintf("Filter Function"), - } - content["Fields"] = fields - - filterTypes := []string{} - filterTypeData := make(map[string]interface{}) - - if BeeApp.Handlers.enableFilter { - var filterType string - - if bf, ok := BeeApp.Handlers.filters[BeforeRouter]; ok { - filterType = "Before Router" - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - - var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - - if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok { - filterType = "Before Exec" - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - - var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - - if bf, ok := BeeApp.Handlers.filters[AfterExec]; ok { - filterType = "After Exec" - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - - var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - - if bf, ok := BeeApp.Handlers.filters[FinishRouter]; ok { - filterType = "Finish Router" - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - - var result = []string{ - fmt.Sprintf("%s", f.pattern), - fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - } - - content["Data"] = filterTypeData - content["Methods"] = filterTypes - - data["Content"] = content - data["Title"] = "Filters" - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(routerAndFilterTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - tmpl.Execute(rw, data) - - default: - rw.Write([]byte("command not support")) + methods = append(methods, method) + methodsData[method] = resultList } - } else { + + content["Data"] = methodsData + content["Methods"] = methods + data["Content"] = content + data["Title"] = "Routers" + execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = map[string]interface{}{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + filterTypes = []string{} + filterTypeData = make(map[string]interface{}) + ) + + if BeeApp.Handlers.enableFilter { + var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router"} { + if bf, ok := BeeApp.Handlers.filters[k]; ok { + filterType = fr + filterTypes = append(filterTypes, filterType) + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + fmt.Sprintf("%s", f.pattern), + fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[filterType] = resultList + } + } + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) } } @@ -276,23 +234,23 @@ func printTree(resultList *[][]string, t *Tree) { if v, ok := l.runObject.(*controllerInfo); ok { if v.routerType == routerTypeBeego { var result = []string{ - fmt.Sprintf("%s", v.pattern), + v.pattern, fmt.Sprintf("%s", v.methods), fmt.Sprintf("%s", v.controllerType), } *resultList = append(*resultList, result) } else if v.routerType == routerTypeRESTFul { var result = []string{ - fmt.Sprintf("%s", v.pattern), + v.pattern, fmt.Sprintf("%s", v.methods), - fmt.Sprintf(""), + "", } *resultList = append(*resultList, result) } else if v.routerType == routerTypeHandler { var result = []string{ - fmt.Sprintf("%s", v.pattern), - fmt.Sprintf(""), - fmt.Sprintf(""), + v.pattern, + "", + "", } *resultList = append(*resultList, result) } @@ -305,53 +263,49 @@ func printTree(resultList *[][]string, t *Tree) { func profIndex(rw http.ResponseWriter, r *http.Request) { r.ParseForm() command := r.Form.Get("command") - format := r.Form.Get("format") - data := make(map[string]interface{}) + if command == "" { + return + } - var result bytes.Buffer - if command != "" { - toolbox.ProcessInput(command, &result) - data["Content"] = result.String() + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + toolbox.ProcessInput(command, &result) + data["Content"] = result.String() - if format == "json" && command == "gc summary" { - dataJson, err := json.Marshal(data) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - - rw.Header().Set("Content-Type", "application/json") - rw.Write(dataJson) + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) return } - data["Title"] = command - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(profillingTpl)) - if command == "gc summary" { - tmpl = template.Must(tmpl.Parse(gcAjaxTpl)) - } else { - - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - } - tmpl.Execute(rw, data) + rw.Header().Set("Content-Type", "application/json") + rw.Write(dataJSON) + return } + + data["Title"] = command + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + execTpl(rw, data, profillingTpl, defaultTpl) } // Healthcheck is a http.Handler calling health checking and showing the result. // it's in "/healthcheck" pattern in admin module. func healthcheck(rw http.ResponseWriter, req *http.Request) { - data := make(map[interface{}]interface{}) - - var result = []string{} - fields := []string{ - fmt.Sprintf("Name"), - fmt.Sprintf("Message"), - fmt.Sprintf("Status"), - } - resultList := new([][]string) - - content := make(map[string]interface{}) + var ( + data = make(map[interface{}]interface{}) + result = []string{} + resultList = new([][]string) + content = map[string]interface{}{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) for name, h := range toolbox.AdminCheckList { if err := h.Check(); err != nil { @@ -371,16 +325,10 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) { } *resultList = append(*resultList, result) } - - content["Fields"] = fields content["Data"] = resultList data["Content"] = content data["Title"] = "Health Check" - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(healthCheckTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - tmpl.Execute(rw, data) - + execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) } // TaskStatus is a http.Handler with running task status (task name, status and the last execution). @@ -392,10 +340,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { req.ParseForm() taskname := req.Form.Get("taskname") if taskname != "" { - if t, ok := toolbox.AdminTaskList[taskname]; ok { - err := t.Run() - if err != nil { + if err := t.Run(); err != nil { data["Message"] = []string{"error", fmt.Sprintf("%s", err)} } data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus())} @@ -409,18 +355,18 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { resultList := new([][]string) var result = []string{} var fields = []string{ - fmt.Sprintf("Task Name"), - fmt.Sprintf("Task Spec"), - fmt.Sprintf("Task Status"), - fmt.Sprintf("Last Time"), - fmt.Sprintf(""), + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", } for tname, tk := range toolbox.AdminTaskList { result = []string{ - fmt.Sprintf("%s", tname), + tname, fmt.Sprintf("%s", tk.GetSpec()), fmt.Sprintf("%s", tk.GetStatus()), - fmt.Sprintf("%s", tk.GetPrev().String()), + tk.GetPrev().String(), } *resultList = append(*resultList, result) } @@ -429,9 +375,14 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { content["Data"] = resultList data["Content"] = content data["Title"] = "Tasks" + execTpl(rw, data, tasksTpl, defaultScriptsTpl) +} + +func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(tasksTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } tmpl.Execute(rw, data) } @@ -451,10 +402,10 @@ func (admin *adminApp) Run() { if len(toolbox.AdminTaskList) > 0 { toolbox.StartTask() } - addr := AdminHttpAddr + addr := BConfig.Listen.AdminAddr - if AdminHttpPort != 0 { - addr = fmt.Sprintf("%s:%d", AdminHttpAddr, AdminHttpPort) + if BConfig.Listen.AdminPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) } for p, f := range admin.routers { http.Handle(p, f) @@ -462,7 +413,7 @@ func (admin *adminApp) Run() { BeeLogger.Info("Admin server Running on %s", addr) var err error - if Graceful { + if BConfig.Listen.Graceful { err = grace.ListenAndServe(addr, nil) } else { err = http.ListenAndServe(addr, nil) diff --git a/app.go b/app.go index 8fc320ad..af54ea4b 100644 --- a/app.go +++ b/app.go @@ -20,15 +20,26 @@ import ( "net/http" "net/http/fcgi" "os" + "path" "time" "github.com/astaxie/beego/grace" "github.com/astaxie/beego/utils" ) +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = NewApp() +} + // App defines beego application with a new PatternServeMux. type App struct { - Handlers *ControllerRegistor + Handlers *ControllerRegister Server *http.Server } @@ -41,132 +52,311 @@ func NewApp() *App { // Run beego application. func (app *App) Run() { - addr := HttpAddr + addr := BConfig.Listen.HTTPAddr - if HttpPort != 0 { - addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort) + if BConfig.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) } var ( - err error - l net.Listener + err error + l net.Listener + endRunning = make(chan bool, 1) ) - endRunning := make(chan bool, 1) - if UseFcgi { - if UseStdIo { - err = fcgi.Serve(nil, app.Handlers) // standard I/O - if err == nil { + // run cgi server + if BConfig.Listen.EnableFcgi { + if BConfig.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O BeeLogger.Info("Use FCGI via standard I/O") } else { - BeeLogger.Info("Cannot use FCGI via standard I/O", err) + BeeLogger.Critical("Cannot use FCGI via standard I/O", err) } + return + } + if BConfig.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) } else { - if HttpPort == 0 { - // remove the Socket file before start - if utils.FileExists(addr) { - os.Remove(addr) + l, err = net.Listen("tcp", addr) + } + if err != nil { + BeeLogger.Critical("Listen: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + BeeLogger.Critical("fcgi.Serve: ", err) + } + return + } + + app.Server.Handler = app.Handlers + app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + + // run graceful mode + if BConfig.Listen.Graceful { + httpsAddr := BConfig.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if BConfig.Listen.EnableHTTPS { + go func() { + time.Sleep(20 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Handlers) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Handlers) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.ListenTCP4 { + server.Network = "tcp4" + } + if err := server.ListenAndServe(); err != nil { + BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + <-endRunning + return + } + + // run normal mode + app.Server.Addr = addr + if BConfig.Listen.EnableHTTPS { + go func() { + time.Sleep(20 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } + BeeLogger.Info("https server Running on %s", app.Server.Addr) + if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + BeeLogger.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + BeeLogger.Info("http server Running on %s", app.Server.Addr) + if BConfig.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return } - l, err = net.Listen("unix", addr) } else { - l, err = net.Listen("tcp", addr) + if err := app.Server.ListenAndServe(); err != nil { + BeeLogger.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } } - if err != nil { - BeeLogger.Critical("Listen: ", err) - } - err = fcgi.Serve(l, app.Handlers) - } - } else { - if Graceful { - app.Server.Addr = addr - app.Server.Handler = app.Handlers - app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second - if EnableHttpTLS { - go func() { - time.Sleep(20 * time.Microsecond) - if HttpsPort != 0 { - addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) - app.Server.Addr = addr - } - server := grace.NewServer(addr, app.Handlers) - server.Server = app.Server - err := server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) - if err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - } - if EnableHttpListen { - go func() { - server := grace.NewServer(addr, app.Handlers) - server.Server = app.Server - if ListenTCP4 && HttpAddr == "" { - server.Network = "tcp4" - } - err := server.ListenAndServe() - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - } - } else { - app.Server.Addr = addr - app.Server.Handler = app.Handlers - app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second - - if EnableHttpTLS { - go func() { - time.Sleep(20 * time.Microsecond) - if HttpsPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort) - } - BeeLogger.Info("https server Running on %s", app.Server.Addr) - err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile) - if err != nil { - BeeLogger.Critical("ListenAndServeTLS: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - } - - if EnableHttpListen { - go func() { - app.Server.Addr = addr - BeeLogger.Info("http server Running on %s", app.Server.Addr) - if ListenTCP4 && HttpAddr == "" { - ln, err := net.Listen("tcp4", app.Server.Addr) - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - err = app.Server.Serve(ln) - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - } else { - err := app.Server.ListenAndServe() - if err != nil { - BeeLogger.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } - } - + }() } <-endRunning } + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + BeeApp.Handlers.Add(rootpath, c, mappingMethods...) + return BeeApp +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +//} +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + BeeApp.Handlers.Include(cList...) + return BeeApp +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + Router(rootpath, c) + Router(path.Join(rootpath, ":objectId"), c) + return BeeApp +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + BeeApp.Handlers.AddAuto(c) + return BeeApp +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.Handlers.AddAutoPrefix(prefix, c) + return BeeApp +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Get(rootpath, f) + return BeeApp +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Post(rootpath, f) + return BeeApp +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Delete(rootpath, f) + return BeeApp +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Put(rootpath, f) + return BeeApp +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Head(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Patch(rootpath, f) + return BeeApp +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Any(rootpath, f) + return BeeApp +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + BeeApp.Handlers.Handler(rootpath, h, options...) + return BeeApp +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) + return BeeApp +} diff --git a/beego.go b/beego.go index cfebfbea..04f02071 100644 --- a/beego.go +++ b/beego.go @@ -12,243 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -// beego is an open-source, high-performance, modularity, full-stack web framework -// -// package main -// -// import "github.com/astaxie/beego" -// -// func main() { -// beego.Run() -// } -// -// more infomation: http://beego.me package beego import ( - "net/http" + "fmt" "os" - "path" "path/filepath" "strconv" "strings" - - "github.com/astaxie/beego/session" ) -// beego web framework version. -const VERSION = "1.5.0" +const ( + // VERSION represent beego web framework version. + VERSION = "1.6.0" -type hookfunc func() error //hook function to run -var hooks []hookfunc //hook function slice to store the hookfunc + // DEV is for develop + DEV = "dev" + // PROD is for production + PROD = "prod" +) -// Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. -// usage: -// simple router -// beego.Router("/admin", &admin.UserController{}) -// beego.Router("/admin/index", &admin.ArticleController{}) -// -// regex router -// -// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) -// -// custom rules -// beego.Router("/api/list",&RestController{},"*:ListFood") -// beego.Router("/api/create",&RestController{},"post:CreateFood") -// beego.Router("/api/update",&RestController{},"put:UpdateFood") -// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - BeeApp.Handlers.Add(rootpath, c, mappingMethods...) - return BeeApp -} +//hook function to run +type hookfunc func() error -// Router add list from -// usage: -// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// type BankAccount struct{ -// beego.Controller -// } -// -// register the function -// func (b *BankAccount)Mapping(){ -// b.Mapping("ShowAccount" , b.ShowAccount) -// b.Mapping("ModifyAccount", b.ModifyAccount) -//} -// -// //@router /account/:id [get] -// func (b *BankAccount) ShowAccount(){ -// //logic -// } -// -// -// //@router /account/:id [post] -// func (b *BankAccount) ModifyAccount(){ -// //logic -// } -// -// the comments @router url methodlist -// url support all the function Router's pattern -// methodlist [get post head put delete options *] -func Include(cList ...ControllerInterface) *App { - BeeApp.Handlers.Include(cList...) - return BeeApp -} +var ( + hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc +) -// RESTRouter adds a restful controller handler to BeeApp. -// its' controller implements beego.ControllerInterface and -// defines a param "pattern/:objectId" to visit each resource. -func RESTRouter(rootpath string, c ControllerInterface) *App { - Router(rootpath, c) - Router(path.Join(rootpath, ":objectId"), c) - return BeeApp -} - -// AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, -// visit the url /main/list to exec List function or /main/page to exec Page function. -func AutoRouter(c ControllerInterface) *App { - BeeApp.Handlers.AddAuto(c) - return BeeApp -} - -// AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, -// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. -func AutoPrefix(prefix string, c ControllerInterface) *App { - BeeApp.Handlers.AddAutoPrefix(prefix, c) - return BeeApp -} - -// register router for Get method -// usage: -// beego.Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Get(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Get(rootpath, f) - return BeeApp -} - -// register router for Post method -// usage: -// beego.Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Post(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Post(rootpath, f) - return BeeApp -} - -// register router for Delete method -// usage: -// beego.Delete("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Delete(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Delete(rootpath, f) - return BeeApp -} - -// register router for Put method -// usage: -// beego.Put("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Put(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Put(rootpath, f) - return BeeApp -} - -// register router for Head method -// usage: -// beego.Head("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Head(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Head(rootpath, f) - return BeeApp -} - -// register router for Options method -// usage: -// beego.Options("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Options(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Options(rootpath, f) - return BeeApp -} - -// register router for Patch method -// usage: -// beego.Patch("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Patch(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Patch(rootpath, f) - return BeeApp -} - -// register router for all method -// usage: -// beego.Any("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Any(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Any(rootpath, f) - return BeeApp -} - -// register router for own Handler -// usage: -// beego.Handler("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Handler(rootpath string, h http.Handler, options ...interface{}) *App { - BeeApp.Handlers.Handler(rootpath, h, options...) - return BeeApp -} - -// SetViewsPath sets view directory path in beego application. -func SetViewsPath(path string) *App { - ViewsPath = path - return BeeApp -} - -// SetStaticPath sets static directory path and proper url pattern in beego application. -// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". -func SetStaticPath(url string, path string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - url = strings.TrimRight(url, "/") - StaticDir[url] = path - return BeeApp -} - -// DelStaticPath removes the static folder setting in this url pattern in beego application. -func DelStaticPath(url string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - url = strings.TrimRight(url, "/") - delete(StaticDir, url) - return BeeApp -} - -// InsertFilter adds a FilterFunc with pattern condition and action constant. -// The pos means action constant including -// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. -// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) - return BeeApp -} - -// The hookfunc will run in beego.Run() +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() // such as sessionInit, middlerware start, buildtemplate, admin start func AddAPPStartHook(hf hookfunc) { hooks = append(hooks, hf) @@ -256,97 +48,60 @@ func AddAPPStartHook(hf hookfunc) { // Run beego application. // beego.Run() default run on HttpPort +// beego.Run("localhost") // beego.Run(":8089") // beego.Run("127.0.0.1:8089") func Run(params ...string) { + initBeforeHTTPRun() + if len(params) > 0 && params[0] != "" { strs := strings.Split(params[0], ":") if len(strs) > 0 && strs[0] != "" { - HttpAddr = strs[0] + BConfig.Listen.HTTPAddr = strs[0] } if len(strs) > 1 && strs[1] != "" { - HttpPort, _ = strconv.Atoi(strs[1]) + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) } } - initBeforeHttpRun() - - if EnableAdmin { - go beeAdminApp.Run() - } BeeApp.Run() } -func initBeforeHttpRun() { - // if AppConfigPath not In the conf/app.conf reParse config - if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") { - err := ParseConfig() - if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") { - // configuration is critical to app, panic here if parse failed - panic(err) - } - } - - //init mime - AddAPPStartHook(initMime) - - // do hooks function - for _, hk := range hooks { - err := hk() - if err != nil { - panic(err) - } - } - - if SessionOn { - var err error - sessionConfig := AppConfig.String("sessionConfig") - if sessionConfig == "" { - sessionConfig = `{"cookieName":"` + SessionName + `",` + - `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` + - `"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` + - `"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` + - `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` + - `"domain":"` + SessionDomain + `",` + - `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}` - } - GlobalSessions, err = session.NewManager(SessionProvider, - sessionConfig) - if err != nil { - panic(err) - } - go GlobalSessions.GC() - } - - err := BuildTemplate(ViewsPath) - if err != nil { - if RunMode == "dev" { - Warn(err) - } - } - - registerDefaultErrorHandler() - - if EnableDocs { - Get("/docs", serverDocs) - Get("/docs/*", serverDocs) - } -} - -// this function is for test package init -func TestBeegoInit(apppath string) { - AppPath = apppath - os.Setenv("BEEGO_RUNMODE", "test") - AppConfigPath = filepath.Join(AppPath, "conf", "app.conf") +func initBeforeHTTPRun() { + // if AppConfigPath is setted or conf/app.conf exist err := ParseConfig() - if err != nil && !os.IsNotExist(err) { - // for init if doesn't have app.conf will not panic - Info(err) + if err != nil { + panic(err) + } + //init log + for adaptor, config := range BConfig.Log.Outputs { + err = BeeLogger.SetLogger(adaptor, config) + if err != nil { + fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err) + } + } + + SetLogFuncCall(BConfig.Log.FileLineNum) + + //init hooks + AddAPPStartHook(registerMime) + AddAPPStartHook(registerDefaultErrorHandler) + AddAPPStartHook(registerSession) + AddAPPStartHook(registerDocs) + AddAPPStartHook(registerTemplate) + AddAPPStartHook(registerAdmin) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } } - os.Chdir(AppPath) - initBeforeHttpRun() } -func init() { - hooks = make([]hookfunc, 0) +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + os.Setenv("BEEGO_RUNMODE", "test") + AppConfigPath = filepath.Join(ap, "conf", "app.conf") + os.Chdir(ap) + initBeforeHTTPRun() } diff --git a/cache/README.md b/cache/README.md index 72d0d1c5..957790e7 100644 --- a/cache/README.md +++ b/cache/README.md @@ -26,7 +26,7 @@ Then init a Cache (example with memory adapter) Use it like this: - bm.Put("astaxie", 1, 10) + bm.Put("astaxie", 1, 10 * time.Second) bm.Get("astaxie") bm.IsExist("astaxie") bm.Delete("astaxie") @@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether ## Memcache adapter -Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. +Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. Configure like this: @@ -52,7 +52,7 @@ Configure like this: ## Redis adapter -Redis adapter use the [redigo](http://github.com/garyburd/redigo/redis) client. +Redis adapter use the [redigo](http://github.com/garyburd/redigo) client. Configure like this: diff --git a/cache/cache.go b/cache/cache.go index 7ca87802..2008402e 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package cache provide a Cache interface and some implemetn engine // Usage: // // import( @@ -22,7 +23,7 @@ // // Use it like this: // -// bm.Put("astaxie", 1, 10) +// bm.Put("astaxie", 1, 10 * time.Second) // bm.Get("astaxie") // bm.IsExist("astaxie") // bm.Delete("astaxie") @@ -32,13 +33,14 @@ package cache import ( "fmt" + "time" ) // Cache interface contains all behaviors for cache adapter. // usage: -// cache.Register("file",cache.NewFileCache()) // this operation is run in init method of file.go. +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. // c,err := cache.NewCache("file","{....}") -// c.Put("key",value,3600) +// c.Put("key",value, 3600 * time.Second) // v := c.Get("key") // // c.Incr("counter") // now is 1 @@ -50,7 +52,7 @@ type Cache interface { // GetMulti is a batch version of Get. GetMulti(keys []string) []interface{} // set cached value with key and expire time. - Put(key string, val interface{}, timeout int64) error + Put(key string, val interface{}, timeout time.Duration) error // delete cached value by key. Delete(key string) error // increase cached int value by key, as a counter. @@ -65,12 +67,14 @@ type Cache interface { StartAndGC(config string) error } -var adapters = make(map[string]Cache) +type CacheInstance func() Cache + +var adapters = make(map[string]CacheInstance) // Register makes a cache adapter available by the adapter name. // If Register is called twice with the same name or if driver is nil, // it panics. -func Register(name string, adapter Cache) { +func Register(name string, adapter CacheInstance) { if adapter == nil { panic("cache: Register adapter is nil") } @@ -80,15 +84,16 @@ func Register(name string, adapter Cache) { adapters[name] = adapter } -// Create a new cache driver by adapter name and config string. +// NewCache Create a new cache driver by adapter name and config string. // config need to be correct JSON as string: {"interval":360}. // it will start gc automatically. func NewCache(adapterName, config string) (adapter Cache, err error) { - adapter, ok := adapters[adapterName] + instanceFunc, ok := adapters[adapterName] if !ok { err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) return } + adapter = instanceFunc() err = adapter.StartAndGC(config) if err != nil { adapter = nil diff --git a/cache/cache_test.go b/cache/cache_test.go index 481309fd..9ceb606a 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -25,7 +25,8 @@ func TestCache(t *testing.T) { if err != nil { t.Error("init err") } - if err = bm.Put("astaxie", 1, 10); err != nil { + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -42,7 +43,7 @@ func TestCache(t *testing.T) { t.Error("check err") } - if err = bm.Put("astaxie", 1, 10); err != nil { + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } @@ -67,7 +68,7 @@ func TestCache(t *testing.T) { } //test GetMulti - if err = bm.Put("astaxie", "author", 10); err != nil { + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -77,7 +78,7 @@ func TestCache(t *testing.T) { t.Error("get err") } - if err = bm.Put("astaxie1", "author1", 10); err != nil { + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie1") { @@ -101,7 +102,8 @@ func TestFileCache(t *testing.T) { if err != nil { t.Error("init err") } - if err = bm.Put("astaxie", 1, 10); err != nil { + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -133,7 +135,7 @@ func TestFileCache(t *testing.T) { } //test string - if err = bm.Put("astaxie", "author", 10); err != nil { + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -144,7 +146,7 @@ func TestFileCache(t *testing.T) { } //test GetMulti - if err = bm.Put("astaxie1", "author1", 10); err != nil { + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie1") { diff --git a/cache/conv.go b/cache/conv.go index 724abfd2..dbdff1c7 100644 --- a/cache/conv.go +++ b/cache/conv.go @@ -19,7 +19,7 @@ import ( "strconv" ) -// convert interface to string. +// GetString convert interface to string. func GetString(v interface{}) string { switch result := v.(type) { case string: @@ -34,7 +34,7 @@ func GetString(v interface{}) string { return "" } -// convert interface to int. +// GetInt convert interface to int. func GetInt(v interface{}) int { switch result := v.(type) { case int: @@ -52,7 +52,7 @@ func GetInt(v interface{}) int { return 0 } -// convert interface to int64. +// GetInt64 convert interface to int64. func GetInt64(v interface{}) int64 { switch result := v.(type) { case int: @@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 { return 0 } -// convert interface to float64. +// GetFloat64 convert interface to float64. func GetFloat64(v interface{}) float64 { switch result := v.(type) { case float64: @@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 { return 0 } -// convert interface to bool. +// GetBool convert interface to bool. func GetBool(v interface{}) bool { switch result := v.(type) { case bool: @@ -98,15 +98,3 @@ func GetBool(v interface{}) bool { } return false } - -// convert interface to byte slice. -func getByteArray(v interface{}) []byte { - switch result := v.(type) { - case []byte: - return result - case string: - return []byte(result) - default: - return nil - } -} diff --git a/cache/conv_test.go b/cache/conv_test.go index 267bb0c9..cf792fa6 100644 --- a/cache/conv_test.go +++ b/cache/conv_test.go @@ -27,7 +27,7 @@ func TestGetString(t *testing.T) { if "test2" != GetString(t2) { t.Error("get string from byte array error") } - var t3 int = 1 + var t3 = 1 if "1" != GetString(t3) { t.Error("get string from int error") } @@ -35,7 +35,7 @@ func TestGetString(t *testing.T) { if "1" != GetString(t4) { t.Error("get string from int64 error") } - var t5 float64 = 1.1 + var t5 = 1.1 if "1.1" != GetString(t5) { t.Error("get string from float64 error") } @@ -46,7 +46,7 @@ func TestGetString(t *testing.T) { } func TestGetInt(t *testing.T) { - var t1 int = 1 + var t1 = 1 if 1 != GetInt(t1) { t.Error("get int from int error") } @@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) { func TestGetInt64(t *testing.T) { var i int64 = 1 - var t1 int = 1 + var t1 = 1 if i != GetInt64(t1) { t.Error("get int64 from int error") } @@ -91,12 +91,12 @@ func TestGetInt64(t *testing.T) { } func TestGetFloat64(t *testing.T) { - var f float64 = 1.11 + var f = 1.11 var t1 float32 = 1.11 if f != GetFloat64(t1) { t.Error("get float64 from float32 error") } - var t2 float64 = 1.11 + var t2 = 1.11 if f != GetFloat64(t2) { t.Error("get float64 from float64 error") } @@ -106,7 +106,7 @@ func TestGetFloat64(t *testing.T) { } var f2 float64 = 1 - var t4 int = 1 + var t4 = 1 if f2 != GetFloat64(t4) { t.Error("get float64 from int error") } @@ -130,21 +130,6 @@ func TestGetBool(t *testing.T) { } } -func TestGetByteArray(t *testing.T) { - var b = []byte("test") - var t1 = []byte("test") - if !byteArrayEquals(b, getByteArray(t1)) { - t.Error("get byte array from byte array error") - } - var t2 = "test" - if !byteArrayEquals(b, getByteArray(t2)) { - t.Error("get byte array from string error") - } - if nil != getByteArray(nil) { - t.Error("get byte array from nil error") - } -} - func byteArrayEquals(a []byte, b []byte) bool { if len(a) != len(b) { return false diff --git a/cache/file.go b/cache/file.go index 65f114f3..3a7aa8b0 100644 --- a/cache/file.go +++ b/cache/file.go @@ -29,23 +29,20 @@ import ( "time" ) -func init() { - Register("file", NewFileCache()) -} - // FileCacheItem is basic unit of file cache adapter. // it contains data and expire time. type FileCacheItem struct { Data interface{} - Lastaccess int64 - Expired int64 + Lastaccess time.Time + Expired time.Time } +// FileCache Config var ( - FileCachePath string = "cache" // cache directory - FileCacheFileSuffix string = ".bin" // cache file suffix - FileCacheDirectoryLevel int = 2 // cache file deep level if auto generated cache files. - FileCacheEmbedExpiry int64 = 0 // cache expire time, default is no expire forever. + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration = 0 // cache expire time, default is no expire forever. ) // FileCache is cache adapter for file storage. @@ -56,14 +53,14 @@ type FileCache struct { EmbedExpiry int } -// Create new file cache with no config. +// NewFileCache Create new file cache with no config. // the level and expiry need set in method StartAndGC as config string. -func NewFileCache() *FileCache { +func NewFileCache() Cache { // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} return &FileCache{} } -// Start and begin gc for file cache. +// StartAndGC will start and begin gc for file cache. // the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} func (fc *FileCache) StartAndGC(config string) error { @@ -79,7 +76,7 @@ func (fc *FileCache) StartAndGC(config string) error { cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) } if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(FileCacheEmbedExpiry, 10) + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) } fc.CachePath = cfg["CachePath"] fc.FileSuffix = cfg["FileSuffix"] @@ -120,13 +117,13 @@ func (fc *FileCache) getCacheFileName(key string) string { // Get value from file cache. // if non-exist or expired, return empty string. func (fc *FileCache) Get(key string) interface{} { - fileData, err := File_get_contents(fc.getCacheFileName(key)) + fileData, err := FileGetContents(fc.getCacheFileName(key)) if err != nil { return "" } var to FileCacheItem - Gob_decode(fileData, &to) - if to.Expired < time.Now().Unix() { + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { return "" } return to.Data @@ -145,21 +142,21 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} { // Put value into file cache. // timeout means how long to keep this file, unit of ms. // if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. -func (fc *FileCache) Put(key string, val interface{}, timeout int64) error { +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { gob.Register(val) item := FileCacheItem{Data: val} if timeout == FileCacheEmbedExpiry { - item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years } else { - item.Expired = time.Now().Unix() + timeout + item.Expired = time.Now().Add(timeout) } - item.Lastaccess = time.Now().Unix() - data, err := Gob_encode(item) + item.Lastaccess = time.Now() + data, err := GobEncode(item) if err != nil { return err } - return File_put_contents(fc.getCacheFileName(key), data) + return FilePutContents(fc.getCacheFileName(key), data) } // Delete file cache value. @@ -171,7 +168,7 @@ func (fc *FileCache) Delete(key string) error { return nil } -// Increase cached int value. +// Incr will increase cached int value. // fc value is saving forever unless Delete. func (fc *FileCache) Incr(key string) error { data := fc.Get(key) @@ -185,7 +182,7 @@ func (fc *FileCache) Incr(key string) error { return nil } -// Decrease cached int value. +// Decr will decrease cached int value. func (fc *FileCache) Decr(key string) error { data := fc.Get(key) var decr int @@ -198,13 +195,13 @@ func (fc *FileCache) Decr(key string) error { return nil } -// Check value is exist. +// IsExist check value is exist. func (fc *FileCache) IsExist(key string) bool { ret, _ := exists(fc.getCacheFileName(key)) return ret } -// Clean cached files. +// ClearAll will clean cached files. // not implemented. func (fc *FileCache) ClearAll() error { return nil @@ -222,9 +219,9 @@ func exists(path string) (bool, error) { return false, err } -// Get bytes to file. +// FileGetContents Get bytes to file. // if non-exist, create this file. -func File_get_contents(filename string) (data []byte, e error) { +func FileGetContents(filename string) (data []byte, e error) { f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) if e != nil { return @@ -242,9 +239,9 @@ func File_get_contents(filename string) (data []byte, e error) { return } -// Put bytes to file. +// FilePutContents Put bytes to file. // if non-exist, create this file. -func File_put_contents(filename string, content []byte) error { +func FilePutContents(filename string, content []byte) error { fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) if err != nil { return err @@ -254,8 +251,8 @@ func File_put_contents(filename string, content []byte) error { return err } -// Gob encodes file cache item. -func Gob_encode(data interface{}) ([]byte, error) { +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { buf := bytes.NewBuffer(nil) enc := gob.NewEncoder(buf) err := enc.Encode(data) @@ -265,9 +262,13 @@ func Gob_encode(data interface{}) ([]byte, error) { return buf.Bytes(), err } -// Gob decodes file cache item. -func Gob_decode(data []byte, to *FileCacheItem) error { +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) return dec.Decode(&to) } + +func init() { + Register("file", NewFileCache) +} diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go index c6829054..15ea5d3e 100644 --- a/cache/memcache/memcache.go +++ b/cache/memcache/memcache.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package memcahe for cache provider +// Package memcache for cache provider // // depend on github.com/bradfitz/gomemcache/memcache // @@ -37,21 +37,22 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/astaxie/beego/cache" + "time" ) -// Memcache adapter. -type MemcacheCache struct { +// Cache Memcache adapter. +type Cache struct { conn *memcache.Client conninfo []string } -// create new memcache adapter. -func NewMemCache() *MemcacheCache { - return &MemcacheCache{} +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} } -// get value from memcache. -func (rc *MemcacheCache) Get(key string) interface{} { +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -63,8 +64,8 @@ func (rc *MemcacheCache) Get(key string) interface{} { return nil } -// get value from memcache. -func (rc *MemcacheCache) GetMulti(keys []string) []interface{} { +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { size := len(keys) var rv []interface{} if rc.conn == nil { @@ -81,16 +82,15 @@ func (rc *MemcacheCache) GetMulti(keys []string) []interface{} { rv = append(rv, string(v.Value)) } return rv - } else { - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv } + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv } -// put value to memcache. only support string. -func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { +// Put put value to memcache. only support string. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -100,12 +100,12 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { if !ok { return errors.New("val must string") } - item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout)} + item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout/time.Second)} return rc.conn.Set(&item) } -// delete value in memcache. -func (rc *MemcacheCache) Delete(key string) error { +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -114,8 +114,8 @@ func (rc *MemcacheCache) Delete(key string) error { return rc.conn.Delete(key) } -// increase counter. -func (rc *MemcacheCache) Incr(key string) error { +// Incr increase counter. +func (rc *Cache) Incr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -125,8 +125,8 @@ func (rc *MemcacheCache) Incr(key string) error { return err } -// decrease counter. -func (rc *MemcacheCache) Decr(key string) error { +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -136,8 +136,8 @@ func (rc *MemcacheCache) Decr(key string) error { return err } -// check value exists in memcache. -func (rc *MemcacheCache) IsExist(key string) bool { +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { if rc.conn == nil { if err := rc.connectInit(); err != nil { return false @@ -150,8 +150,8 @@ func (rc *MemcacheCache) IsExist(key string) bool { return true } -// clear all cached in memcache. -func (rc *MemcacheCache) ClearAll() error { +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { if rc.conn == nil { if err := rc.connectInit(); err != nil { return err @@ -160,10 +160,10 @@ func (rc *MemcacheCache) ClearAll() error { return rc.conn.FlushAll() } -// start memcache adapter. +// StartAndGC start memcache adapter. // config string is like {"conn":"connection info"}. // if connecting error, return. -func (rc *MemcacheCache) StartAndGC(config string) error { +func (rc *Cache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) if _, ok := cf["conn"]; !ok { @@ -179,11 +179,11 @@ func (rc *MemcacheCache) StartAndGC(config string) error { } // connect to memcache and keep the connection. -func (rc *MemcacheCache) connectInit() error { +func (rc *Cache) connectInit() error { rc.conn = memcache.New(rc.conninfo...) return nil } func init() { - cache.Register("memcache", NewMemCache()) + cache.Register("memcache", NewMemCache) } diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go index 0523ae85..19629059 100644 --- a/cache/memcache/memcache_test.go +++ b/cache/memcache/memcache_test.go @@ -23,12 +23,13 @@ import ( "time" ) -func TestRedisCache(t *testing.T) { +func TestMemcacheCache(t *testing.T) { bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) if err != nil { t.Error("init err") } - if err = bm.Put("astaxie", "1", 10); err != nil { + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -40,7 +41,7 @@ func TestRedisCache(t *testing.T) { if bm.IsExist("astaxie") { t.Error("check err") } - if err = bm.Put("astaxie", "1", 10); err != nil { + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { t.Error("set Error", err) } @@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) { } //test string - if err = bm.Put("astaxie", "author", 10); err != nil { + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) { } //test GetMulti - if err = bm.Put("astaxie1", "author1", 10); err != nil { + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie1") { diff --git a/cache/memory.go b/cache/memory.go index b6657048..d928afdb 100644 --- a/cache/memory.go +++ b/cache/memory.go @@ -17,34 +17,41 @@ package cache import ( "encoding/json" "errors" - "fmt" "sync" "time" ) var ( - // clock time of recycling the expired cache items in memory. - DefaultEvery int = 60 // 1 minute + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute ) -// Memory cache item. +// MemoryItem store memory cache item. type MemoryItem struct { - val interface{} - Lastaccess time.Time - expired int64 + val interface{} + createdTime time.Time + lifespan time.Duration } -// Memory cache adapter. +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. // it contains a RW locker for safe map storage. type MemoryCache struct { - lock sync.RWMutex + sync.RWMutex dur time.Duration items map[string]*MemoryItem Every int // run an expiration check Every clock time } // NewMemoryCache returns a new MemoryCache. -func NewMemoryCache() *MemoryCache { +func NewMemoryCache() Cache { cache := MemoryCache{items: make(map[string]*MemoryItem)} return &cache } @@ -52,11 +59,10 @@ func NewMemoryCache() *MemoryCache { // Get cache from memory. // if non-existed or expired, return nil. func (bc *MemoryCache) Get(name string) interface{} { - bc.lock.RLock() - defer bc.lock.RUnlock() + bc.RLock() + defer bc.RUnlock() if itm, ok := bc.items[name]; ok { - if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired { - go bc.Delete(name) + if itm.isExpire() { return nil } return itm.val @@ -75,22 +81,22 @@ func (bc *MemoryCache) GetMulti(names []string) []interface{} { } // Put cache to memory. -// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute). -func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error { - bc.lock.Lock() - defer bc.lock.Unlock() +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() bc.items[name] = &MemoryItem{ val: value, - Lastaccess: time.Now(), - expired: expired, + createdTime: time.Now(), + lifespan: lifespan, } return nil } -/// Delete cache in memory. +// Delete cache in memory. func (bc *MemoryCache) Delete(name string) error { - bc.lock.Lock() - defer bc.lock.Unlock() + bc.Lock() + defer bc.Unlock() if _, ok := bc.items[name]; !ok { return errors.New("key not exist") } @@ -101,11 +107,11 @@ func (bc *MemoryCache) Delete(name string) error { return nil } -// Increase cache counter in memory. -// it supports int,int64,int32,uint,uint64,uint32. +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. func (bc *MemoryCache) Incr(key string) error { - bc.lock.RLock() - defer bc.lock.RUnlock() + bc.RLock() + defer bc.RUnlock() itm, ok := bc.items[key] if !ok { return errors.New("key not exist") @@ -113,10 +119,10 @@ func (bc *MemoryCache) Incr(key string) error { switch itm.val.(type) { case int: itm.val = itm.val.(int) + 1 - case int64: - itm.val = itm.val.(int64) + 1 case int32: itm.val = itm.val.(int32) + 1 + case int64: + itm.val = itm.val.(int64) + 1 case uint: itm.val = itm.val.(uint) + 1 case uint32: @@ -124,15 +130,15 @@ func (bc *MemoryCache) Incr(key string) error { case uint64: itm.val = itm.val.(uint64) + 1 default: - return errors.New("item val is not int int64 int32") + return errors.New("item val is not (u)int (u)int32 (u)int64") } return nil } -// Decrease counter in memory. +// Decr decrease counter in memory. func (bc *MemoryCache) Decr(key string) error { - bc.lock.RLock() - defer bc.lock.RUnlock() + bc.RLock() + defer bc.RUnlock() itm, ok := bc.items[key] if !ok { return errors.New("key not exist") @@ -168,23 +174,25 @@ func (bc *MemoryCache) Decr(key string) error { return nil } -// check cache exist in memory. +// IsExist check cache exist in memory. func (bc *MemoryCache) IsExist(name string) bool { - bc.lock.RLock() - defer bc.lock.RUnlock() - _, ok := bc.items[name] - return ok + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false } -// delete all cache in memory. +// ClearAll will delete all cache in memory. func (bc *MemoryCache) ClearAll() error { - bc.lock.Lock() - defer bc.lock.Unlock() + bc.Lock() + defer bc.Unlock() bc.items = make(map[string]*MemoryItem) return nil } -// start memory cache. it will check expiration in every clock time. +// StartAndGC start memory cache. it will check expiration in every clock time. func (bc *MemoryCache) StartAndGC(config string) error { var cf map[string]int json.Unmarshal([]byte(config), &cf) @@ -192,10 +200,7 @@ func (bc *MemoryCache) StartAndGC(config string) error { cf = make(map[string]int) cf["interval"] = DefaultEvery } - dur, err := time.ParseDuration(fmt.Sprintf("%ds", cf["interval"])) - if err != nil { - return err - } + dur := time.Duration(cf["interval"]) * time.Second bc.Every = cf["interval"] bc.dur = dur go bc.vaccuum() @@ -213,20 +218,21 @@ func (bc *MemoryCache) vaccuum() { return } for name := range bc.items { - bc.item_expired(name) + bc.itemExpired(name) } } } -// item_expired returns true if an item is expired. -func (bc *MemoryCache) item_expired(name string) bool { - bc.lock.Lock() - defer bc.lock.Unlock() +// itemExpired returns true if an item is expired. +func (bc *MemoryCache) itemExpired(name string) bool { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[name] if !ok { return true } - if time.Now().Unix()-itm.Lastaccess.Unix() >= itm.expired { + if itm.isExpire() { delete(bc.items, name) return true } @@ -234,5 +240,5 @@ func (bc *MemoryCache) item_expired(name string) bool { } func init() { - Register("memory", NewMemoryCache()) + Register("memory", NewMemoryCache) } diff --git a/cache/redis/redis.go b/cache/redis/redis.go index d14b1ada..781e3836 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package redis for cache provider +// Package redis for cache provider // // depend on github.com/garyburd/redigo/redis // @@ -41,12 +41,12 @@ import ( ) var ( - // the collection name of redis for cache adapter. - DefaultKey string = "beecacheRedis" + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" ) -// Redis cache adapter. -type RedisCache struct { +// Cache is Redis cache adapter. +type Cache struct { p *redis.Pool // redis connection pool conninfo string dbNum int @@ -54,13 +54,13 @@ type RedisCache struct { password string } -// create new redis cache with default collection name. -func NewRedisCache() *RedisCache { - return &RedisCache{key: DefaultKey} +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} } // actually do the redis cmds -func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) { +func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { c := rc.p.Get() defer c.Close() @@ -68,7 +68,7 @@ func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interfa } // Get cache from redis. -func (rc *RedisCache) Get(key string) interface{} { +func (rc *Cache) Get(key string) interface{} { if v, err := rc.do("GET", key); err == nil { return v } @@ -76,7 +76,7 @@ func (rc *RedisCache) Get(key string) interface{} { } // GetMulti get cache from redis. -func (rc *RedisCache) GetMulti(keys []string) []interface{} { +func (rc *Cache) GetMulti(keys []string) []interface{} { size := len(keys) var rv []interface{} c := rc.p.Get() @@ -108,10 +108,10 @@ ERROR: return rv } -// put cache to redis. -func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { +// Put put cache to redis. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { var err error - if _, err = rc.do("SETEX", key, timeout, val); err != nil { + if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil { return err } @@ -121,8 +121,8 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { return err } -// delete cache in redis. -func (rc *RedisCache) Delete(key string) error { +// Delete delete cache in redis. +func (rc *Cache) Delete(key string) error { var err error if _, err = rc.do("DEL", key); err != nil { return err @@ -131,8 +131,8 @@ func (rc *RedisCache) Delete(key string) error { return err } -// check cache's existence in redis. -func (rc *RedisCache) IsExist(key string) bool { +// IsExist check cache's existence in redis. +func (rc *Cache) IsExist(key string) bool { v, err := redis.Bool(rc.do("EXISTS", key)) if err != nil { return false @@ -145,20 +145,20 @@ func (rc *RedisCache) IsExist(key string) bool { return v } -// increase counter in redis. -func (rc *RedisCache) Incr(key string) error { +// Incr increase counter in redis. +func (rc *Cache) Incr(key string) error { _, err := redis.Bool(rc.do("INCRBY", key, 1)) return err } -// decrease counter in redis. -func (rc *RedisCache) Decr(key string) error { +// Decr decrease counter in redis. +func (rc *Cache) Decr(key string) error { _, err := redis.Bool(rc.do("INCRBY", key, -1)) return err } -// clean all cache in redis. delete this redis collection. -func (rc *RedisCache) ClearAll() error { +// ClearAll clean all cache in redis. delete this redis collection. +func (rc *Cache) ClearAll() error { cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key)) if err != nil { return err @@ -172,11 +172,11 @@ func (rc *RedisCache) ClearAll() error { return err } -// start redis cache adapter. +// StartAndGC start redis cache adapter. // config is like {"key":"collection key","conn":"connection info","dbNum":"0"} // the cache item in redis are stored forever, // so no gc operation. -func (rc *RedisCache) StartAndGC(config string) error { +func (rc *Cache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) @@ -206,7 +206,7 @@ func (rc *RedisCache) StartAndGC(config string) error { } // connect to redis. -func (rc *RedisCache) connectInit() { +func (rc *Cache) connectInit() { dialFunc := func() (c redis.Conn, err error) { c, err = redis.Dial("tcp", rc.conninfo) if err != nil { @@ -236,5 +236,5 @@ func (rc *RedisCache) connectInit() { } func init() { - cache.Register("redis", NewRedisCache()) + cache.Register("redis", NewRedisCache) } diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 1f74fd27..47c5acc6 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -28,19 +28,20 @@ func TestRedisCache(t *testing.T) { if err != nil { t.Error("init err") } - if err = bm.Put("astaxie", 1, 10); err != nil { + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { t.Error("check err") } - time.Sleep(10 * time.Second) + time.Sleep(11 * time.Second) if bm.IsExist("astaxie") { t.Error("check err") } - if err = bm.Put("astaxie", 1, 10); err != nil { + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } @@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) { } //test string - if err = bm.Put("astaxie", "author", 10); err != nil { + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie") { @@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) { } //test GetMulti - if err = bm.Put("astaxie1", "author1", 10); err != nil { + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { t.Error("set Error", err) } if !bm.IsExist("astaxie1") { diff --git a/config.go b/config.go index 09d0df24..cffb7c2b 100644 --- a/config.go +++ b/config.go @@ -15,78 +15,264 @@ package beego import ( - "fmt" "html/template" "os" "path/filepath" - "runtime" "strings" "github.com/astaxie/beego/config" - "github.com/astaxie/beego/logs" "github.com/astaxie/beego/session" "github.com/astaxie/beego/utils" ) -var ( - BeeApp *App // beego application - AppName string - AppPath string - workPath string - AppConfigPath string +type BeegoConfig struct { + AppName string //Application name + RunMode string //Running Mode: dev | prod + RouterCaseSensitive bool + ServerName string + RecoverPanic bool + CopyRequestBody bool + EnableGzip bool + MaxMemory int64 + EnableErrorsShow bool + Listen Listen + WebConfig WebConfig + Log LogConfig +} + +type Listen struct { + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + EnableHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O +} + +type WebConfig struct { + AutoRender bool + EnableDocs bool + FlashName string + FlashSeparator string + DirectoryIndex bool StaticDir map[string]string - TemplateCache map[string]*template.Template // template caching map - StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc) - EnableHttpListen bool - HttpAddr string - HttpPort int - ListenTCP4 bool - EnableHttpTLS bool - HttpsPort int - HttpCertFile string - HttpKeyFile string - RecoverPanic bool // flag of auto recover panic - AutoRender bool // flag of render template automatically - ViewsPath string - AppConfig *beegoAppConfig - RunMode string // run mode, "dev" or "prod" - GlobalSessions *session.Manager // global session mananger - SessionOn bool // flag of starting session auto. default is false. - SessionProvider string // default session provider, memory, mysql , redis ,etc. - SessionName string // the cookie name when saving session id into cookie. - SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session. - SessionSavePath string // if use mysql/redis/file provider, define save path to connection info. - SessionCookieLifeTime int // the life time of session id in cookie. - SessionAutoSetCookie bool // auto setcookie - SessionDomain string // the cookie domain default is empty - UseFcgi bool - UseStdIo bool - MaxMemory int64 - EnableGzip bool // flag of enable gzip - DirectoryIndex bool // flag of display directory index. default is false. - HttpServerTimeOut int64 - ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template. - XSRFKEY string // xsrf hash salt string. - EnableXSRF bool // flag of enable xsrf. - XSRFExpire int // the expiry of xsrf value. - CopyRequestBody bool // flag of copy raw request body in context. + StaticExtensionsToGzip []string TemplateLeft string TemplateRight string - BeegoServerName string // beego server name exported in response header. - EnableAdmin bool // flag of enable admin module to log every request info. - AdminHttpAddr string // http server configurations for admin module. - AdminHttpPort int - FlashName string // name of the flash variable found in response header and cookie - FlashSeperator string // used to seperate flash key:value - AppConfigProvider string // config provider - EnableDocs bool // enable generate docs & server docs API Swagger - RouterCaseSensitive bool // router case sensitive default is true - AccessLogs bool // print access logs, default is false - Graceful bool // use graceful start the server + ViewsPath string + EnableXSRF bool + XSRFKey string + XSRFExpire int + Session SessionConfig +} + +type SessionConfig struct { + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string +} + +type LogConfig struct { + AccessLogs bool + FileLineNum bool + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *BeegoConfig + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppConfigPath is the path to the config files + AppConfigPath string + // AppConfigProvider is the provider for the config, default is ini + AppConfigProvider = "ini" + // TemplateCache stores template caching + TemplateCache map[string]*template.Template + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager ) +func init() { + BConfig = &BeegoConfig{ + AppName: "beego", + RunMode: DEV, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + VERSION, + RecoverPanic: true, + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, //64MB + EnableErrorsShow: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionCookieLifeTime: 0, //set cookie default is the brower life + SessionAutoSetCookie: true, + SessionDomain: "", + }, + }, + Log: LogConfig{ + AccessLogs: false, + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + } + ParseConfig() +} + +// ParseConfig parsed default config file. +// now only support ini, next will support json. +func ParseConfig() (err error) { + if AppConfigPath == "" { + if utils.FileExists(filepath.Join("conf", "app.conf")) { + AppConfigPath = filepath.Join("conf", "app.conf") + } else { + AppConfig = &beegoAppConfig{config.NewFakeConfig()} + return + } + } + AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath) + if err != nil { + return err + } + // set the runmode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runmode := AppConfig.String("RunMode"); runmode != "" { + BConfig.RunMode = runmode + } + + BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName) + BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic) + BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive) + BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName) + BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip) + BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow) + BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody) + BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory) + BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful) + BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr") + BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort) + BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4) + BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP) + BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS) + BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr) + BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort) + BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile) + BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile) + BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin) + BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr) + BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort) + BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi) + BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo) + BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut) + BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender) + BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath) + BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex) + BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName) + BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator) + BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs) + BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey) + BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF) + BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire) + BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft) + BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight) + BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn) + BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider) + BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName) + BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig) + BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime) + BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime) + BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie) + BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain) + + if sd := AppConfig.String("StaticDir"); sd != "" { + for k := range BConfig.WebConfig.StaticDir { + delete(BConfig.WebConfig.StaticDir, k) + } + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + return nil +} + type beegoAppConfig struct { - innerConfig config.ConfigContainer + innerConfig config.Configer } func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) { @@ -94,109 +280,95 @@ func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, err if err != nil { return nil, err } - rac := &beegoAppConfig{ac} - return rac, nil + return &beegoAppConfig{ac}, nil } func (b *beegoAppConfig) Set(key, val string) error { - err := b.innerConfig.Set(RunMode+"::"+key, val) - if err == nil { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { return err } return b.innerConfig.Set(key, val) } func (b *beegoAppConfig) String(key string) string { - v := b.innerConfig.String(RunMode + "::" + key) - if v == "" { - return b.innerConfig.String(key) + if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { + return v } - return v + return b.innerConfig.String(key) } func (b *beegoAppConfig) Strings(key string) []string { - v := b.innerConfig.Strings(RunMode + "::" + key) - if v[0] == "" { - return b.innerConfig.Strings(key) + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" { + return v } - return v + return b.innerConfig.Strings(key) } func (b *beegoAppConfig) Int(key string) (int, error) { - v, err := b.innerConfig.Int(RunMode + "::" + key) - if err != nil { - return b.innerConfig.Int(key) + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil } - return v, nil + return b.innerConfig.Int(key) } func (b *beegoAppConfig) Int64(key string) (int64, error) { - v, err := b.innerConfig.Int64(RunMode + "::" + key) - if err != nil { - return b.innerConfig.Int64(key) + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil } - return v, nil + return b.innerConfig.Int64(key) } func (b *beegoAppConfig) Bool(key string) (bool, error) { - v, err := b.innerConfig.Bool(RunMode + "::" + key) - if err != nil { - return b.innerConfig.Bool(key) + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil } - return v, nil + return b.innerConfig.Bool(key) } func (b *beegoAppConfig) Float(key string) (float64, error) { - v, err := b.innerConfig.Float(RunMode + "::" + key) - if err != nil { - return b.innerConfig.Float(key) + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil } - return v, nil + return b.innerConfig.Float(key) } func (b *beegoAppConfig) DefaultString(key string, defaultval string) string { - v := b.String(key) - if v != "" { + if v := b.String(key); v != "" { return v } return defaultval } func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string { - v := b.Strings(key) - if len(v) != 0 { + if v := b.Strings(key); len(v) != 0 { return v } return defaultval } func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int { - v, err := b.Int(key) - if err == nil { + if v, err := b.Int(key); err == nil { return v } return defaultval } func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 { - v, err := b.Int64(key) - if err == nil { + if v, err := b.Int64(key); err == nil { return v } return defaultval } func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool { - v, err := b.Bool(key) - if err == nil { + if v, err := b.Bool(key); err == nil { return v } return defaultval } func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 { - v, err := b.Float(key) - if err == nil { + if v, err := b.Float(key); err == nil { return v } return defaultval @@ -213,305 +385,3 @@ func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { func (b *beegoAppConfig) SaveConfigFile(filename string) error { return b.innerConfig.SaveConfigFile(filename) } - -func init() { - // create beego application - BeeApp = NewApp() - - workPath, _ = os.Getwd() - workPath, _ = filepath.Abs(workPath) - // initialize default configurations - AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) - - AppConfigPath = filepath.Join(AppPath, "conf", "app.conf") - - if workPath != AppPath { - if utils.FileExists(AppConfigPath) { - os.Chdir(AppPath) - } else { - AppConfigPath = filepath.Join(workPath, "conf", "app.conf") - } - } - - AppConfigProvider = "ini" - - StaticDir = make(map[string]string) - StaticDir["/static"] = "static" - - StaticExtensionsToGzip = []string{".css", ".js"} - - TemplateCache = make(map[string]*template.Template) - - // set this to 0.0.0.0 to make this app available to externally - EnableHttpListen = true //default enable http Listen - - HttpAddr = "" - HttpPort = 8080 - - HttpsPort = 10443 - - AppName = "beego" - - RunMode = "dev" //default runmod - - AutoRender = true - - RecoverPanic = true - - ViewsPath = "views" - - SessionOn = false - SessionProvider = "memory" - SessionName = "beegosessionID" - SessionGCMaxLifetime = 3600 - SessionSavePath = "" - SessionCookieLifeTime = 0 //set cookie default is the brower life - SessionAutoSetCookie = true - - UseFcgi = false - UseStdIo = false - - MaxMemory = 1 << 26 //64MB - - EnableGzip = false - - HttpServerTimeOut = 0 - - ErrorsShow = true - - XSRFKEY = "beegoxsrf" - XSRFExpire = 0 - - TemplateLeft = "{{" - TemplateRight = "}}" - - BeegoServerName = "beegoServer:" + VERSION - - EnableAdmin = false - AdminHttpAddr = "127.0.0.1" - AdminHttpPort = 8088 - - FlashName = "BEEGO_FLASH" - FlashSeperator = "BEEGOFLASH" - - RouterCaseSensitive = true - - runtime.GOMAXPROCS(runtime.NumCPU()) - - // init BeeLogger - BeeLogger = logs.NewLogger(10000) - err := BeeLogger.SetLogger("console", "") - if err != nil { - fmt.Println("init console log error:", err) - } - SetLogFuncCall(true) - - err = ParseConfig() - if err != nil && os.IsNotExist(err) { - // for init if doesn't have app.conf will not panic - ac := config.NewFakeConfig() - AppConfig = &beegoAppConfig{ac} - Warning(err) - } -} - -// ParseConfig parsed default config file. -// now only support ini, next will support json. -func ParseConfig() (err error) { - AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath) - if err != nil { - return err - } - envRunMode := os.Getenv("BEEGO_RUNMODE") - // set the runmode first - if envRunMode != "" { - RunMode = envRunMode - } else if runmode := AppConfig.String("RunMode"); runmode != "" { - RunMode = runmode - } - - HttpAddr = AppConfig.String("HttpAddr") - - if v, err := AppConfig.Int("HttpPort"); err == nil { - HttpPort = v - } - - if v, err := AppConfig.Bool("ListenTCP4"); err == nil { - ListenTCP4 = v - } - - if v, err := AppConfig.Bool("EnableHttpListen"); err == nil { - EnableHttpListen = v - } - - if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil { - MaxMemory = maxmemory - } - - if appname := AppConfig.String("AppName"); appname != "" { - AppName = appname - } - - if autorender, err := AppConfig.Bool("AutoRender"); err == nil { - AutoRender = autorender - } - - if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil { - RecoverPanic = autorecover - } - - if views := AppConfig.String("ViewsPath"); views != "" { - ViewsPath = views - } - - if sessionon, err := AppConfig.Bool("SessionOn"); err == nil { - SessionOn = sessionon - } - - if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" { - SessionProvider = sessProvider - } - - if sessName := AppConfig.String("SessionName"); sessName != "" { - SessionName = sessName - } - - if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" { - SessionSavePath = sesssavepath - } - - if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 { - SessionGCMaxLifetime = sessMaxLifeTime - } - - if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 { - SessionCookieLifeTime = sesscookielifetime - } - - if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil { - UseFcgi = usefcgi - } - - if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil { - EnableGzip = enablegzip - } - - if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil { - DirectoryIndex = directoryindex - } - - if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil { - HttpServerTimeOut = timeout - } - - if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil { - ErrorsShow = errorsshow - } - - if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil { - CopyRequestBody = copyrequestbody - } - - if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" { - XSRFKEY = xsrfkey - } - - if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil { - EnableXSRF = enablexsrf - } - - if expire, err := AppConfig.Int("XSRFExpire"); err == nil { - XSRFExpire = expire - } - - if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" { - TemplateLeft = tplleft - } - - if tplright := AppConfig.String("TemplateRight"); tplright != "" { - TemplateRight = tplright - } - - if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil { - EnableHttpTLS = httptls - } - - if httpsport, err := AppConfig.Int("HttpsPort"); err == nil { - HttpsPort = httpsport - } - - if certfile := AppConfig.String("HttpCertFile"); certfile != "" { - HttpCertFile = certfile - } - - if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" { - HttpKeyFile = keyfile - } - - if serverName := AppConfig.String("BeegoServerName"); serverName != "" { - BeegoServerName = serverName - } - - if flashname := AppConfig.String("FlashName"); flashname != "" { - FlashName = flashname - } - - if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" { - FlashSeperator = flashseperator - } - - if sd := AppConfig.String("StaticDir"); sd != "" { - for k := range StaticDir { - delete(StaticDir, k) - } - sds := strings.Fields(sd) - for _, v := range sds { - if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { - StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1] - } else { - StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0] - } - } - } - - if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { - extensions := strings.Split(sgz, ",") - if len(extensions) > 0 { - StaticExtensionsToGzip = []string{} - for _, ext := range extensions { - if len(ext) == 0 { - continue - } - extWithDot := ext - if extWithDot[:1] != "." { - extWithDot = "." + extWithDot - } - StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot) - } - } - } - - if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil { - EnableAdmin = enableadmin - } - - if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" { - AdminHttpAddr = adminhttpaddr - } - - if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil { - AdminHttpPort = adminhttpport - } - - if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil { - EnableDocs = enabledocs - } - - if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil { - RouterCaseSensitive = casesensitive - } - if graceful, err := AppConfig.Bool("Graceful"); err == nil { - Graceful = graceful - } - return nil -} diff --git a/config/config.go b/config/config.go index 8d9261b8..da5d358b 100644 --- a/config/config.go +++ b/config/config.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package config is used to parse config // Usage: // import( // "github.com/astaxie/beego/config" @@ -28,12 +29,12 @@ // cnf.Int64(key string) (int64, error) // cnf.Bool(key string) (bool, error) // cnf.Float(key string) (float64, error) -// cnf.DefaultString(key string, defaultval string) string -// cnf.DefaultStrings(key string, defaultval []string) []string -// cnf.DefaultInt(key string, defaultval int) int -// cnf.DefaultInt64(key string, defaultval int64) int64 -// cnf.DefaultBool(key string, defaultval bool) bool -// cnf.DefaultFloat(key string, defaultval float64) float64 +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 // cnf.DIY(key string) (interface{}, error) // cnf.GetSection(section string) (map[string]string, error) // cnf.SaveConfigFile(filename string) error @@ -45,30 +46,30 @@ import ( "fmt" ) -// ConfigContainer defines how to get and set value from configuration raw data. -type ConfigContainer interface { - Set(key, val string) error // support section::key type in given key when using ini type. - String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error //support section::key type in given key when using ini type. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. Strings(key string) []string //get string slice Int(key string) (int, error) Int64(key string) (int64, error) Bool(key string) (bool, error) Float(key string) (float64, error) - DefaultString(key string, defaultval string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultStrings(key string, defaultval []string) []string //get string slice - DefaultInt(key string, defaultval int) int - DefaultInt64(key string, defaultval int64) int64 - DefaultBool(key string, defaultval bool) bool - DefaultFloat(key string, defaultval float64) float64 + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string //get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 DIY(key string) (interface{}, error) GetSection(section string) (map[string]string, error) SaveConfigFile(filename string) error } -// Config is the adapter interface for parsing config file to get raw data to ConfigContainer. +// Config is the adapter interface for parsing config file to get raw data to Configer. type Config interface { - Parse(key string) (ConfigContainer, error) - ParseData(data []byte) (ConfigContainer, error) + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) } var adapters = make(map[string]Config) @@ -86,19 +87,19 @@ func Register(name string, adapter Config) { adapters[name] = adapter } -// adapterName is ini/json/xml/yaml. +// NewConfig adapterName is ini/json/xml/yaml. // filename is the config file path. -func NewConfig(adapterName, fileaname string) (ConfigContainer, error) { +func NewConfig(adapterName, filename string) (Configer, error) { adapter, ok := adapters[adapterName] if !ok { return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) } - return adapter.Parse(fileaname) + return adapter.Parse(filename) } -// adapterName is ini/json/xml/yaml. +// NewConfigData adapterName is ini/json/xml/yaml. // data is the config data. -func NewConfigData(adapterName string, data []byte) (ConfigContainer, error) { +func NewConfigData(adapterName string, data []byte) (Configer, error) { adapter, ok := adapters[adapterName] if !ok { return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) diff --git a/config/fake.go b/config/fake.go index 54588e5e..50aa5d4a 100644 --- a/config/fake.go +++ b/config/fake.go @@ -38,11 +38,11 @@ func (c *fakeConfigContainer) String(key string) string { } func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - if v := c.getData(key); v == "" { + v := c.getData(key) + if v == "" { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) Strings(key string) []string { @@ -50,11 +50,11 @@ func (c *fakeConfigContainer) Strings(key string) []string { } func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); len(v) == 0 { + v := c.Strings(key) + if len(v) == 0 { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) Int(key string) (int, error) { @@ -62,11 +62,11 @@ func (c *fakeConfigContainer) Int(key string) (int, error) { } func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err != nil { + v, err := c.Int(key) + if err != nil { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) Int64(key string) (int64, error) { @@ -74,11 +74,11 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) { } func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err != nil { + v, err := c.Int64(key) + if err != nil { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) Bool(key string) (bool, error) { @@ -86,11 +86,11 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) { } func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err != nil { + v, err := c.Bool(key) + if err != nil { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) Float(key string) (float64, error) { @@ -98,11 +98,11 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) { } func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err != nil { + v, err := c.Float(key) + if err != nil { return defaultval - } else { - return v } + return v } func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { @@ -120,9 +120,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") } -var _ ConfigContainer = new(fakeConfigContainer) +var _ Configer = new(fakeConfigContainer) -func NewFakeConfig() ConfigContainer { +// NewFakeConfig return a fake Congiger +func NewFakeConfig() Configer { return &fakeConfigContainer{ data: make(map[string]string), } diff --git a/config/ini.go b/config/ini.go index 31fe9b5f..59e84e1e 100644 --- a/config/ini.go +++ b/config/ini.go @@ -31,23 +31,23 @@ import ( ) var ( - DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section, - bNumComment = []byte{'#'} // number signal - bSemComment = []byte{';'} // semicolon signal - bEmpty = []byte{} - bEqual = []byte{'='} // equal signal - bDQuote = []byte{'"'} // quote signal - sectionStart = []byte{'['} // section start signal - sectionEnd = []byte{']'} // section end signal - lineBreak = "\n" + defaultSection = "default" // default section means if some ini items not in a section, make them in default section, + bNumComment = []byte{'#'} // number signal + bSemComment = []byte{';'} // semicolon signal + bEmpty = []byte{} + bEqual = []byte{'='} // equal signal + bDQuote = []byte{'"'} // quote signal + sectionStart = []byte{'['} // section start signal + sectionEnd = []byte{']'} // section end signal + lineBreak = "\n" ) // IniConfig implements Config to parse ini file. type IniConfig struct { } -// ParseFile creates a new Config and parses the file configuration from the named file. -func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { +// Parse creates a new Config and parses the file configuration from the named file. +func (ini *IniConfig) Parse(name string) (Configer, error) { return ini.parseFile(name) } @@ -77,7 +77,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { buf.ReadByte() } } - section := DEFAULT_SECTION + section := defaultSection for { line, _, err := buf.ReadLine() if err == io.EOF { @@ -171,7 +171,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { return cfg, nil } -func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) { +// ParseData parse ini the data +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { // Save memory data to temporary file tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) os.MkdirAll(path.Dir(tmpName), os.ModePerm) @@ -181,7 +182,7 @@ func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) { return ini.Parse(tmpName) } -// A Config represents the ini configuration. +// IniConfigContainer A Config represents the ini configuration. // When set and get value, support key as section:name type. type IniConfigContainer struct { filename string @@ -199,11 +200,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) { // DefaultBool returns the boolean value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err != nil { + v, err := c.Bool(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int returns the integer value for a given key. @@ -214,11 +215,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err != nil { + v, err := c.Int(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int64 returns the int64 value for a given key. @@ -229,11 +230,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err != nil { + v, err := c.Int64(key) + if err != nil { return defaultval - } else { - return v } + return v } // Float returns the float value for a given key. @@ -244,11 +245,11 @@ func (c *IniConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err != nil { + v, err := c.Float(key) + if err != nil { return defaultval - } else { - return v } + return v } // String returns the string value for a given key. @@ -259,11 +260,11 @@ func (c *IniConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - if v := c.String(key); v == "" { + v := c.String(key) + if v == "" { return defaultval - } else { - return v } + return v } // Strings returns the []string value for a given key. @@ -274,20 +275,19 @@ func (c *IniConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaltval func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); len(v) == 0 { + v := c.Strings(key) + if len(v) == 0 { return defaultval - } else { - return v } + return v } // GetSection returns map for the given section func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v, nil - } else { - return nil, errors.New("not exist setction") } + return nil, errors.New("not exist setction") } // SaveConfigFile save the config into file @@ -301,7 +301,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { buf := bytes.NewBuffer(nil) // Save default section at first place - if dt, ok := c.data[DEFAULT_SECTION]; ok { + if dt, ok := c.data[defaultSection]; ok { for key, val := range dt { if key != " " { // Write key comments. @@ -325,7 +325,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { } // Save named sections for section, dt := range c.data { - if section != DEFAULT_SECTION { + if section != defaultSection { // Write section comments. if v, ok := c.sectionComment[section]; ok { if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil { @@ -367,7 +367,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { return nil } -// WriteValue writes a new value for key. +// Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. func (c *IniConfigContainer) Set(key, value string) error { @@ -379,14 +379,14 @@ func (c *IniConfigContainer) Set(key, value string) error { var ( section, k string - sectionKey []string = strings.Split(key, "::") + sectionKey = strings.Split(key, "::") ) if len(sectionKey) >= 2 { section = sectionKey[0] k = sectionKey[1] } else { - section = DEFAULT_SECTION + section = defaultSection k = sectionKey[0] } @@ -415,13 +415,13 @@ func (c *IniConfigContainer) getdata(key string) string { var ( section, k string - sectionKey []string = strings.Split(strings.ToLower(key), "::") + sectionKey = strings.Split(strings.ToLower(key), "::") ) if len(sectionKey) >= 2 { section = sectionKey[0] k = sectionKey[1] } else { - section = DEFAULT_SECTION + section = defaultSection k = sectionKey[0] } if v, ok := c.data[section]; ok { diff --git a/config/json.go b/config/json.go index e2b53793..6929baad 100644 --- a/config/json.go +++ b/config/json.go @@ -23,12 +23,12 @@ import ( "sync" ) -// JsonConfig is a json config parser and implements Config interface. -type JsonConfig struct { +// JSONConfig is a json config parser and implements Config interface. +type JSONConfig struct { } // Parse returns a ConfigContainer with parsed json config map. -func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) { +func (js *JSONConfig) Parse(filename string) (Configer, error) { file, err := os.Open(filename) if err != nil { return nil, err @@ -43,8 +43,8 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) { } // ParseData returns a ConfigContainer with json string -func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) { - x := &JsonConfigContainer{ +func (js *JSONConfig) ParseData(data []byte) (Configer, error) { + x := &JSONConfigContainer{ data: make(map[string]interface{}), } err := json.Unmarshal(data, &x.data) @@ -59,15 +59,15 @@ func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) { return x, nil } -// A Config represents the json configuration. +// JSONConfigContainer A Config represents the json configuration. // Only when get value, support key as section:name type. -type JsonConfigContainer struct { +type JSONConfigContainer struct { data map[string]interface{} sync.RWMutex } // Bool returns the boolean value for a given key. -func (c *JsonConfigContainer) Bool(key string) (bool, error) { +func (c *JSONConfigContainer) Bool(key string) (bool, error) { val := c.getData(key) if val != nil { if v, ok := val.(bool); ok { @@ -80,7 +80,7 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool { +func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { if v, err := c.Bool(key); err == nil { return v } @@ -88,7 +88,7 @@ func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool { } // Int returns the integer value for a given key. -func (c *JsonConfigContainer) Int(key string) (int, error) { +func (c *JSONConfigContainer) Int(key string) (int, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -101,7 +101,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaltval -func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int { +func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { if v, err := c.Int(key); err == nil { return v } @@ -109,7 +109,7 @@ func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int { } // Int64 returns the int64 value for a given key. -func (c *JsonConfigContainer) Int64(key string) (int64, error) { +func (c *JSONConfigContainer) Int64(key string) (int64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -122,7 +122,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaltval -func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { if v, err := c.Int64(key); err == nil { return v } @@ -130,7 +130,7 @@ func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 { } // Float returns the float value for a given key. -func (c *JsonConfigContainer) Float(key string) (float64, error) { +func (c *JSONConfigContainer) Float(key string) (float64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -143,7 +143,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaltval -func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { if v, err := c.Float(key); err == nil { return v } @@ -151,7 +151,7 @@ func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float } // String returns the string value for a given key. -func (c *JsonConfigContainer) String(key string) string { +func (c *JSONConfigContainer) String(key string) string { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { @@ -163,7 +163,7 @@ func (c *JsonConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaltval -func (c *JsonConfigContainer) DefaultString(key string, defaultval string) string { +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { // TODO FIXME should not use "" to replace non existance if v := c.String(key); v != "" { return v @@ -172,7 +172,7 @@ func (c *JsonConfigContainer) DefaultString(key string, defaultval string) strin } // Strings returns the []string value for a given key. -func (c *JsonConfigContainer) Strings(key string) []string { +func (c *JSONConfigContainer) Strings(key string) []string { stringVal := c.String(key) if stringVal == "" { return []string{} @@ -182,7 +182,7 @@ func (c *JsonConfigContainer) Strings(key string) []string { // DefaultStrings returns the []string value for a given key. // if err != nil return defaltval -func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []string { +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { if v := c.Strings(key); len(v) > 0 { return v } @@ -190,7 +190,7 @@ func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) [] } // GetSection returns map for the given section -func (c *JsonConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil } @@ -198,7 +198,7 @@ func (c *JsonConfigContainer) GetSection(section string) (map[string]string, err } // SaveConfigFile save the config into file -func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -214,7 +214,7 @@ func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *JsonConfigContainer) Set(key, val string) error { +func (c *JSONConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -222,7 +222,7 @@ func (c *JsonConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { val := c.getData(key) if val != nil { return val, nil @@ -231,7 +231,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) { } // section.key or key -func (c *JsonConfigContainer) getData(key string) interface{} { +func (c *JSONConfigContainer) getData(key string) interface{} { if len(key) == 0 { return nil } @@ -261,5 +261,5 @@ func (c *JsonConfigContainer) getData(key string) interface{} { } func init() { - Register("json", &JsonConfig{}) + Register("json", &JSONConfig{}) } diff --git a/config/xml/xml.go b/config/xml/xml.go index a1d9fcdb..4d48f7d2 100644 --- a/config/xml/xml.go +++ b/config/xml/xml.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package xml for config provider +// Package xml for config provider // // depend on github.com/beego/x2j // @@ -45,20 +45,20 @@ import ( "github.com/beego/x2j" ) -// XmlConfig is a xml config parser and implements Config interface. +// Config is a xml config parser and implements Config interface. // xml configurations should be included in tag. // only support key/value pair as value as each item. -type XMLConfig struct{} +type Config struct{} // Parse returns a ConfigContainer with parsed xml config map. -func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) { +func (xc *Config) Parse(filename string) (config.Configer, error) { file, err := os.Open(filename) if err != nil { return nil, err } defer file.Close() - x := &XMLConfigContainer{data: make(map[string]interface{})} + x := &ConfigContainer{data: make(map[string]interface{})} content, err := ioutil.ReadAll(file) if err != nil { return nil, err @@ -73,84 +73,86 @@ func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) { return x, nil } -func (x *XMLConfig) ParseData(data []byte) (config.ConfigContainer, error) { +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { // Save memory data to temporary file tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) os.MkdirAll(path.Dir(tmpName), os.ModePerm) if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { return nil, err } - return x.Parse(tmpName) + return xc.Parse(tmpName) } -// A Config represents the xml configuration. -type XMLConfigContainer struct { +// ConfigContainer A Config represents the xml configuration. +type ConfigContainer struct { data map[string]interface{} sync.Mutex } // Bool returns the boolean value for a given key. -func (c *XMLConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(key string) (bool, error) { return strconv.ParseBool(c.data[key].(string)) } // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *XMLConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err != nil { +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int returns the integer value for a given key. -func (c *XMLConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.data[key].(string)) } // DefaultInt returns the integer value for a given key. // if err != nil return defaltval -func (c *XMLConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err != nil { +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int64 returns the int64 value for a given key. -func (c *XMLConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.data[key].(string), 10, 64) } // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaltval -func (c *XMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err != nil { +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { return defaultval - } else { - return v } + return v + } // Float returns the float value for a given key. -func (c *XMLConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.data[key].(string), 64) } // DefaultFloat returns the float64 value for a given key. // if err != nil return defaltval -func (c *XMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err != nil { +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { return defaultval - } else { - return v } + return v } // String returns the string value for a given key. -func (c *XMLConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) string { if v, ok := c.data[key].(string); ok { return v } @@ -159,40 +161,39 @@ func (c *XMLConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaltval -func (c *XMLConfigContainer) DefaultString(key string, defaultval string) string { - if v := c.String(key); v == "" { +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { return defaultval - } else { - return v } + return v } // Strings returns the []string value for a given key. -func (c *XMLConfigContainer) Strings(key string) []string { +func (c *ConfigContainer) Strings(key string) []string { return strings.Split(c.String(key), ";") } // DefaultStrings returns the []string value for a given key. // if err != nil return defaltval -func (c *XMLConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); len(v) == 0 { +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if len(v) == 0 { return defaultval - } else { - return v } + return v } // GetSection returns map for the given section -func (c *XMLConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil - } else { - return nil, errors.New("not exist setction") } + return nil, errors.New("not exist setction") } // SaveConfigFile save the config into file -func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -207,8 +208,8 @@ func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) { return err } -// WriteValue writes a new value for key. -func (c *XMLConfigContainer) Set(key, val string) error { +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -216,7 +217,7 @@ func (c *XMLConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil } @@ -224,5 +225,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) { } func init() { - config.Register("xml", &XMLConfig{}) + config.Register("xml", &Config{}) } diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index c5be44a9..f034d3ba 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package yaml for config provider +// Package yaml for config provider // // depend on github.com/beego/goyaml2 // @@ -46,22 +46,23 @@ import ( "github.com/beego/goyaml2" ) -// YAMLConfig is a yaml config parser and implements Config interface. -type YAMLConfig struct{} +// Config is a yaml config parser and implements Config interface. +type Config struct{} // Parse returns a ConfigContainer with parsed yaml config map. -func (yaml *YAMLConfig) Parse(filename string) (y config.ConfigContainer, err error) { +func (yaml *Config) Parse(filename string) (y config.Configer, err error) { cnf, err := ReadYmlReader(filename) if err != nil { return } - y = &YAMLConfigContainer{ + y = &ConfigContainer{ data: cnf, } return } -func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) { +// ParseData parse yaml data +func (yaml *Config) ParseData(data []byte) (config.Configer, error) { // Save memory data to temporary file tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) os.MkdirAll(path.Dir(tmpName), os.ModePerm) @@ -71,7 +72,7 @@ func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) { return yaml.Parse(tmpName) } -// Read yaml file to map. +// ReadYmlReader Read yaml file to map. // if json like, use json package, unless goyaml2 package. func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { f, err := os.Open(path) @@ -112,14 +113,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { return } -// A Config represents the yaml configuration. -type YAMLConfigContainer struct { +// ConfigContainer A Config represents the yaml configuration. +type ConfigContainer struct { data map[string]interface{} sync.Mutex } // Bool returns the boolean value for a given key. -func (c *YAMLConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(key string) (bool, error) { if v, ok := c.data[key].(bool); ok { return v, nil } @@ -128,16 +129,16 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *YAMLConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err != nil { +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int returns the integer value for a given key. -func (c *YAMLConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(key string) (int, error) { if v, ok := c.data[key].(int64); ok { return int(v), nil } @@ -146,16 +147,16 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaltval -func (c *YAMLConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err != nil { +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { return defaultval - } else { - return v } + return v } // Int64 returns the int64 value for a given key. -func (c *YAMLConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(key string) (int64, error) { if v, ok := c.data[key].(int64); ok { return v, nil } @@ -164,16 +165,16 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaltval -func (c *YAMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err != nil { +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { return defaultval - } else { - return v } + return v } // Float returns the float value for a given key. -func (c *YAMLConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(key string) (float64, error) { if v, ok := c.data[key].(float64); ok { return v, nil } @@ -182,16 +183,16 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaltval -func (c *YAMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err != nil { +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { return defaultval - } else { - return v } + return v } // String returns the string value for a given key. -func (c *YAMLConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) string { if v, ok := c.data[key].(string); ok { return v } @@ -200,40 +201,40 @@ func (c *YAMLConfigContainer) String(key string) string { // DefaultString returns the string value for a given key. // if err != nil return defaltval -func (c *YAMLConfigContainer) DefaultString(key string, defaultval string) string { - if v := c.String(key); v == "" { +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { return defaultval - } else { - return v } + return v } // Strings returns the []string value for a given key. -func (c *YAMLConfigContainer) Strings(key string) []string { +func (c *ConfigContainer) Strings(key string) []string { return strings.Split(c.String(key), ";") } // DefaultStrings returns the []string value for a given key. // if err != nil return defaltval -func (c *YAMLConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); len(v) == 0 { +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if len(v) == 0 { return defaultval - } else { - return v } + return v } // GetSection returns map for the given section -func (c *YAMLConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + v, ok := c.data[section] + if ok { return v.(map[string]string), nil - } else { - return nil, errors.New("not exist setction") } + return nil, errors.New("not exist setction") } // SaveConfigFile save the config into file -func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -244,8 +245,8 @@ func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) { return err } -// WriteValue writes a new value for key. -func (c *YAMLConfigContainer) Set(key, val string) error { +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -253,7 +254,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil } @@ -261,5 +262,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) { } func init() { - config.Register("yaml", &YAMLConfig{}) + config.Register("yaml", &Config{}) } diff --git a/config_test.go b/config_test.go index 17645f80..cf4a781d 100644 --- a/config_test.go +++ b/config_test.go @@ -19,11 +19,11 @@ import ( ) func TestDefaults(t *testing.T) { - if FlashName != "BEEGO_FLASH" { + if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { t.Errorf("FlashName was not set to default.") } - if FlashSeperator != "BEEGOFLASH" { + if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { t.Errorf("FlashName was not set to default.") } } diff --git a/context/acceptencoder.go b/context/acceptencoder.go new file mode 100644 index 00000000..07c5cb0b --- /dev/null +++ b/context/acceptencoder.go @@ -0,0 +1,198 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +type resetWriter interface { + io.Writer + Reset(w io.Writer) +} + +type nopResetWriter struct { + io.Writer +} + +func (n nopResetWriter) Reset(w io.Writer) { + //do nothing +} + +type acceptEncoder struct { + name string + levelEncode func(int) resetWriter + bestSpeedPool *sync.Pool + bestCompressionPool *sync.Pool +} + +func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { + if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + return nopResetWriter{wr} + } + var rwr resetWriter + switch level { + case flate.BestSpeed: + rwr = ac.bestSpeedPool.Get().(resetWriter) + case flate.BestCompression: + rwr = ac.bestCompressionPool.Get().(resetWriter) + default: + rwr = ac.levelEncode(level) + } + rwr.Reset(wr) + return rwr +} + +func (ac acceptEncoder) put(wr resetWriter, level int) { + if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil { + return + } + wr.Reset(nil) + switch level { + case flate.BestSpeed: + ac.bestSpeedPool.Put(wr) + case flate.BestCompression: + ac.bestCompressionPool.Put(wr) + } +} + +var ( + noneCompressEncoder = acceptEncoder{"", nil, nil, nil} + gzipCompressEncoder = acceptEncoder{"gzip", + func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + &sync.Pool{ + New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr }, + }, + &sync.Pool{ + New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }, + }, + } + + //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed + //deflate + //The "zlib" format defined in RFC 1950 [31] in combination with + //the "deflate" compression mechanism described in RFC 1951 [29]. + deflateCompressEncoder = acceptEncoder{"deflate", + func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + &sync.Pool{ + New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr }, + }, + &sync.Pool{ + New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }, + }, + } +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter resetWriter + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter = ce.encode(writer, level) + defer ce.put(outputWriter, level) + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + return parseEncoding(r) +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + if len(vs) == 1 { + lastQ = q{vs[0], 1} + break + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{vs[0], f} + } + } + } + if cf, ok := encoderMap[lastQ.name]; ok { + return cf.name + } else { + return "" + } +} diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go new file mode 100644 index 00000000..147313c5 --- /dev/null +++ b/context/acceptencoder_test.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"*"}}}) != "gzip" { + t.Fail() + } +} diff --git a/context/context.go b/context/context.go index f6aa85d6..db790ff2 100644 --- a/context/context.go +++ b/context/context.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package context provide the context utils // Usage: // // import "github.com/astaxie/beego/context" @@ -22,10 +23,13 @@ package context import ( + "bufio" "crypto/hmac" "crypto/sha1" "encoding/base64" + "errors" "fmt" + "net" "net/http" "strconv" "strings" @@ -34,14 +38,30 @@ import ( "github.com/astaxie/beego/utils" ) -// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. // BeegoInput and BeegoOutput provides some api to operate request and response more easily. type Context struct { Input *BeegoInput Output *BeegoOutput Request *http.Request - ResponseWriter http.ResponseWriter - _xsrf_token string + ResponseWriter *Response + _xsrfToken string +} + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + ctx.ResponseWriter = &Response{rw, false, 0} + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) } // Redirect does redirection to localurl with http header status code. @@ -54,29 +74,28 @@ func (ctx *Context) Redirect(status int, localurl string) { // Abort stops this request. // if beego.ErrorMaps exists, panic body. func (ctx *Context) Abort(status int, body string) { - ctx.ResponseWriter.WriteHeader(status) panic(body) } -// Write string to response body. +// WriteString Write string to response body. // it sends response body. func (ctx *Context) WriteString(content string) { ctx.ResponseWriter.Write([]byte(content)) } -// Get cookie from request by a given key. +// GetCookie Get cookie from request by a given key. // It's alias of BeegoInput.Cookie. func (ctx *Context) GetCookie(key string) string { return ctx.Input.Cookie(key) } -// Set cookie for response. +// SetCookie Set cookie for response. // It's alias of BeegoOutput.Cookie. func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { ctx.Output.Cookie(name, value, others...) } -// Get secure cookie from request by a given key. +// GetSecureCookie Get secure cookie from request by a given key. func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { val := ctx.Input.Cookie(key) if val == "" { @@ -103,7 +122,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { return string(res), true } -// Set Secure cookie for response. +// SetSecureCookie Set Secure cookie for response. func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { vs := base64.URLEncoding.EncodeToString([]byte(value)) timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) @@ -114,23 +133,23 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf ctx.Output.Cookie(name, cookie, others...) } -// XsrfToken creates a xsrf token string and returns. -func (ctx *Context) XsrfToken(key string, expire int64) string { - if ctx._xsrf_token == "" { +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { token, ok := ctx.GetSecureCookie(key, "_xsrf") if !ok { token = string(utils.RandomCreateBytes(32)) ctx.SetSecureCookie(key, "_xsrf", token, expire) } - ctx._xsrf_token = token + ctx._xsrfToken = token } - return ctx._xsrf_token + return ctx._xsrfToken } -// CheckXsrfCookie checks xsrf token in this request is valid or not. +// CheckXSRFCookie checks xsrf token in this request is valid or not. // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // or in form field value named as "_xsrf". -func (ctx *Context) CheckXsrfCookie() bool { +func (ctx *Context) CheckXSRFCookie() bool { token := ctx.Input.Query("_xsrf") if token == "" { token = ctx.Request.Header.Get("X-Xsrftoken") @@ -142,9 +161,57 @@ func (ctx *Context) CheckXsrfCookie() bool { ctx.Abort(403, "'_xsrf' argument missing from POST") return false } - if ctx._xsrf_token != token { + if ctx._xsrfToken != token { ctx.Abort(403, "XSRF cookie does not match POST argument") return false } return true } + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (w *Response) Write(p []byte) (int, error) { + w.Started = true + return w.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (w *Response) WriteHeader(code int) { + w.Status = code + w.Started = true + w.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (w *Response) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (w *Response) CloseNotify() <-chan bool { + if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} diff --git a/context/input.go b/context/input.go index 1985df21..c37204bd 100644 --- a/context/input.go +++ b/context/input.go @@ -17,8 +17,8 @@ package context import ( "bytes" "errors" + "io" "io/ioutil" - "net/http" "net/url" "reflect" "regexp" @@ -31,45 +31,55 @@ import ( // Regexes for checking the accept headers // TODO make sure these are correct var ( - acceptsHtmlRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) - acceptsXmlRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) - acceptsJsonRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) + acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) + acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + maxParam = 50 ) // BeegoInput operates the http request header, data, cookie and body. // it also contains router params and current session. type BeegoInput struct { - CruSession session.SessionStore - Params map[string]string - Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. - Request *http.Request - RequestBody []byte - RunController reflect.Type - RunMethod string + Context *Context + CruSession session.Store + pnames []string + pvalues []string + data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + RequestBody []byte } -// NewInput return BeegoInput generated by http.Request. -func NewInput(req *http.Request) *BeegoInput { +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { return &BeegoInput{ - Params: make(map[string]string), - Data: make(map[interface{}]interface{}), - Request: req, + pnames: make([]string, 0, maxParam), + pvalues: make([]string, 0, maxParam), + data: make(map[interface{}]interface{}), } } +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + input.Context = ctx + input.CruSession = nil + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] + input.data = nil + input.RequestBody = []byte{} +} + // Protocol returns request protocol name, such as HTTP/1.1 . func (input *BeegoInput) Protocol() string { - return input.Request.Proto + return input.Context.Request.Proto } -// Uri returns full request url with query string, fragment. -func (input *BeegoInput) Uri() string { - return input.Request.RequestURI +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI } -// Url returns request url path (without query string, fragment). -func (input *BeegoInput) Url() string { - return input.Request.URL.Path +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return input.Context.Request.URL.Path } // Site returns base site url as scheme://domain type. @@ -79,10 +89,10 @@ func (input *BeegoInput) Site() string { // Scheme returns request scheme as "http" or "https". func (input *BeegoInput) Scheme() string { - if input.Request.URL.Scheme != "" { - return input.Request.URL.Scheme + if input.Context.Request.URL.Scheme != "" { + return input.Context.Request.URL.Scheme } - if input.Request.TLS == nil { + if input.Context.Request.TLS == nil { return "http" } return "https" @@ -97,19 +107,19 @@ func (input *BeegoInput) Domain() string { // Host returns host name. // if no host info in request, return localhost. func (input *BeegoInput) Host() string { - if input.Request.Host != "" { - hostParts := strings.Split(input.Request.Host, ":") + if input.Context.Request.Host != "" { + hostParts := strings.Split(input.Context.Request.Host, ":") if len(hostParts) > 0 { return hostParts[0] } - return input.Request.Host + return input.Context.Request.Host } return "localhost" } // Method returns http request method. func (input *BeegoInput) Method() string { - return input.Request.Method + return input.Context.Request.Method } // Is returns boolean of this request is on given method, such as Is("POST"). @@ -117,37 +127,37 @@ func (input *BeegoInput) Is(method string) bool { return input.Method() == method } -// Is this a GET method request? +// IsGet Is this a GET method request? func (input *BeegoInput) IsGet() bool { return input.Is("GET") } -// Is this a POST method request? +// IsPost Is this a POST method request? func (input *BeegoInput) IsPost() bool { return input.Is("POST") } -// Is this a Head method request? +// IsHead Is this a Head method request? func (input *BeegoInput) IsHead() bool { return input.Is("HEAD") } -// Is this a OPTIONS method request? +// IsOptions Is this a OPTIONS method request? func (input *BeegoInput) IsOptions() bool { return input.Is("OPTIONS") } -// Is this a PUT method request? +// IsPut Is this a PUT method request? func (input *BeegoInput) IsPut() bool { return input.Is("PUT") } -// Is this a DELETE method request? +// IsDelete Is this a DELETE method request? func (input *BeegoInput) IsDelete() bool { return input.Is("DELETE") } -// Is this a PATCH method request? +// IsPatch Is this a PATCH method request? func (input *BeegoInput) IsPatch() bool { return input.Is("PATCH") } @@ -172,19 +182,19 @@ func (input *BeegoInput) IsUpload() bool { return strings.Contains(input.Header("Content-Type"), "multipart/form-data") } -// Checks if request accepts html response -func (input *BeegoInput) AcceptsHtml() bool { - return acceptsHtmlRegex.MatchString(input.Header("Accept")) +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return acceptsHTMLRegex.MatchString(input.Header("Accept")) } -// Checks if request accepts xml response -func (input *BeegoInput) AcceptsXml() bool { - return acceptsXmlRegex.MatchString(input.Header("Accept")) +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return acceptsXMLRegex.MatchString(input.Header("Accept")) } -// Checks if request accepts json response -func (input *BeegoInput) AcceptsJson() bool { - return acceptsJsonRegex.MatchString(input.Header("Accept")) +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return acceptsJSONRegex.MatchString(input.Header("Accept")) } // IP returns request client ip. @@ -196,7 +206,7 @@ func (input *BeegoInput) IP() string { rip := strings.Split(ips[0], ":") return rip[0] } - ip := strings.Split(input.Request.RemoteAddr, ":") + ip := strings.Split(input.Context.Request.RemoteAddr, ":") if len(ip) > 0 { if ip[0] != "[" { return ip[0] @@ -236,7 +246,7 @@ func (input *BeegoInput) SubDomains() string { // Port returns request client port. // when error or empty, return 80. func (input *BeegoInput) Port() int { - parts := strings.Split(input.Request.Host, ":") + parts := strings.Split(input.Context.Request.Host, ":") if len(parts) == 2 { port, _ := strconv.Atoi(parts[1]) return port @@ -249,35 +259,59 @@ func (input *BeegoInput) UserAgent() string { return input.Header("User-Agent") } +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return len(input.pnames) +} + // Param returns router param by a given key. func (input *BeegoInput) Param(key string) string { - if v, ok := input.Params[key]; ok { - return v + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + return input.pvalues[i] + } } return "" } +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + m := make(map[string]string) + for i, v := range input.pnames { + if i <= len(input.pvalues) { + m[v] = input.pvalues[i] + } + } + return m +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + input.pvalues = append(input.pvalues, val) + input.pnames = append(input.pnames, key) +} + // Query returns input data item string by a given string. func (input *BeegoInput) Query(key string) string { if val := input.Param(key); val != "" { return val } - if input.Request.Form == nil { - input.Request.ParseForm() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() } - return input.Request.Form.Get(key) + return input.Context.Request.Form.Get(key) } // Header returns request header item string by a given string. // if non-existed, return empty string. func (input *BeegoInput) Header(key string) string { - return input.Request.Header.Get(key) + return input.Context.Request.Header.Get(key) } // Cookie returns request cookie item string by a given key. // if non-existed, return empty string. func (input *BeegoInput) Cookie(key string) string { - ck, err := input.Request.Cookie(key) + ck, err := input.Context.Request.Cookie(key) if err != nil { return "" } @@ -291,18 +325,27 @@ func (input *BeegoInput) Session(key interface{}) interface{} { } // CopyBody returns the raw request body data as bytes. -func (input *BeegoInput) CopyBody() []byte { - requestbody, _ := ioutil.ReadAll(input.Request.Body) - input.Request.Body.Close() +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} + requestbody, _ := ioutil.ReadAll(safe) + input.Context.Request.Body.Close() bf := bytes.NewBuffer(requestbody) - input.Request.Body = ioutil.NopCloser(bf) + input.Context.Request.Body = ioutil.NopCloser(bf) input.RequestBody = requestbody return requestbody } +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + return input.data +} + // GetData returns the stored data in this context. func (input *BeegoInput) GetData(key interface{}) interface{} { - if v, ok := input.Data[key]; ok { + if v, ok := input.data[key]; ok { return v } return nil @@ -311,17 +354,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} { // SetData stores data with given key in this context. // This data are only available in this context. func (input *BeegoInput) SetData(key, val interface{}) { - input.Data[key] = val + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + input.data[key] = val } -// parseForm or parseMultiForm based on Content-type +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { // Parse the body depending on the content type. if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { - if err := input.Request.ParseMultipartForm(maxMemory); err != nil { + if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { return errors.New("Error parsing request body:" + err.Error()) } - } else if err := input.Request.ParseForm(); err != nil { + } else if err := input.Context.Request.ParseForm(); err != nil { return errors.New("Error parsing request body:" + err.Error()) } return nil @@ -353,7 +399,7 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error { } func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { - rv := reflect.Zero(reflect.TypeOf(0)) + rv := reflect.Zero(typ) switch typ.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: val := input.Query(key) @@ -386,19 +432,19 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { } rv = input.bindBool(val, typ) case reflect.Slice: - rv = input.bindSlice(&input.Request.Form, key, typ) + rv = input.bindSlice(&input.Context.Request.Form, key, typ) case reflect.Struct: - rv = input.bindStruct(&input.Request.Form, key, typ) + rv = input.bindStruct(&input.Context.Request.Form, key, typ) case reflect.Ptr: rv = input.bindPoint(key, typ) case reflect.Map: - rv = input.bindMap(&input.Request.Form, key, typ) + rv = input.bindMap(&input.Context.Request.Form, key, typ) } return rv } func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { - rv := reflect.Zero(reflect.TypeOf(0)) + rv := reflect.Zero(typ) switch typ.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: rv = input.bindInt(val, typ) diff --git a/context/input_test.go b/context/input_test.go index 4566f6d6..618e1254 100644 --- a/context/input_test.go +++ b/context/input_test.go @@ -17,12 +17,15 @@ package context import ( "fmt" "net/http" + "net/http/httptest" "testing" ) func TestParse(t *testing.T) { r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) - beegoInput := NewInput(r) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) beegoInput.ParseFormOrMulitForm(1 << 20) var id int @@ -73,7 +76,9 @@ func TestParse(t *testing.T) { func TestSubDomain(t *testing.T) { r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) - beegoInput := NewInput(r) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) subdomain := beegoInput.SubDomains() if subdomain != "www" { @@ -81,13 +86,13 @@ func TestSubDomain(t *testing.T) { } r, _ = http.NewRequest("GET", "http://localhost/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "" { t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) } r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "aa.bb" { t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) } @@ -101,13 +106,13 @@ func TestSubDomain(t *testing.T) { */ r, _ = http.NewRequest("GET", "http://example.com/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "" { t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) } r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) - beegoInput.Request = r + beegoInput.Context.Request = r if beegoInput.SubDomains() != "aa.bb.cc.dd" { t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) } diff --git a/context/output.go b/context/output.go index 7edde552..2d756e27 100644 --- a/context/output.go +++ b/context/output.go @@ -16,8 +16,6 @@ package context import ( "bytes" - "compress/flate" - "compress/gzip" "encoding/json" "encoding/xml" "errors" @@ -45,6 +43,12 @@ func NewOutput() *BeegoOutput { return &BeegoOutput{} } +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + output.Context = ctx + output.Status = 0 +} + // Header sets response header item string via given key. func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) @@ -54,30 +58,16 @@ func (output *BeegoOutput) Header(key, val string) { // if EnableGzip, compress content string. // it sends out response body directly. func (output *BeegoOutput) Body(content []byte) { - output_writer := output.Context.ResponseWriter.(io.Writer) - if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" { - splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1) - encodings := make([]string, len(splitted)) - - for i, val := range splitted { - encodings[i] = strings.TrimSpace(val) - } - for _, val := range encodings { - if val == "gzip" { - output.Header("Content-Encoding", "gzip") - output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed) - - break - } else if val == "deflate" { - output.Header("Content-Encoding", "deflate") - output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed) - break - } - } + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) } else { output.Header("Content-Length", strconv.Itoa(len(content))) } - // Write status code if it has been set manually // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" if output.Status != 0 { @@ -85,13 +75,7 @@ func (output *BeegoOutput) Body(content []byte) { output.Status = 0 } - output_writer.Write(content) - switch output_writer.(type) { - case *gzip.Writer: - output_writer.(*gzip.Writer).Close() - case *flate.Writer: - output_writer.(*flate.Writer).Close() - } + io.Copy(output.Context.ResponseWriter, buf) } // Cookie sets cookie value via given key. @@ -100,29 +84,25 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface var b bytes.Buffer fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) - //fix cookie not work in IE - if len(others) > 0 { - switch v := others[0].(type) { - case int: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) - } else if v < 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } - case int64: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) - } else if v < 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } - case int32: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) - } else if v < 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } - } - } + //fix cookie not work in IE + if len(others) > 0 { + var maxAge int64 + + switch v := others[0].(type) { + case int: + maxAge = int64(v) + case int32: + maxAge = int64(v) + case int64: + maxAge = v + } + + if maxAge > 0 { + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + } else { + fmt.Fprintf(&b, "; Max-Age=0") + } + } // the settings below // Path, Domain, Secure, HttpOnly @@ -188,9 +168,9 @@ func sanitizeValue(v string) string { return cookieValueSanitizer.Replace(v) } -// Json writes json to response body. +// JSON writes json to response body. // if coding is true, it converts utf-8 to \u0000 type. -func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error { +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { output.Header("Content-Type", "application/json; charset=utf-8") var content []byte var err error @@ -204,14 +184,14 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e return err } if coding { - content = []byte(stringsToJson(string(content))) + content = []byte(stringsToJSON(string(content))) } output.Body(content) return nil } -// Jsonp writes jsonp to response body. -func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error { +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/javascript; charset=utf-8") var content []byte var err error @@ -228,16 +208,16 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error { if callback == "" { return errors.New(`"callback" parameter required`) } - callback_content := bytes.NewBufferString(" " + template.JSEscapeString(callback)) - callback_content.WriteString("(") - callback_content.Write(content) - callback_content.WriteString(");\r\n") - output.Body(callback_content.Bytes()) + callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback)) + callbackContent.WriteString("(") + callbackContent.Write(content) + callbackContent.WriteString(");\r\n") + output.Body(callbackContent.Bytes()) return nil } -// Xml writes xml string to response body. -func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error { +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/xml; charset=utf-8") var content []byte var err error @@ -331,7 +311,7 @@ func (output *BeegoOutput) IsNotFound(status int) bool { return output.Status == 404 } -// IsClient returns boolean of this request client sends error data. +// IsClientError returns boolean of this request client sends error data. // HTTP 4xx means forbidden. func (output *BeegoOutput) IsClientError(status int) bool { return output.Status >= 400 && output.Status < 500 @@ -343,7 +323,7 @@ func (output *BeegoOutput) IsServerError(status int) bool { return output.Status >= 500 && output.Status < 600 } -func stringsToJson(str string) string { +func stringsToJSON(str string) string { rs := []rune(str) jsons := "" for _, r := range rs { @@ -357,7 +337,7 @@ func stringsToJson(str string) string { return jsons } -// Sessions sets session item value with given key. +// Session sets session item value with given key. func (output *BeegoOutput) Session(name interface{}, value interface{}) { output.Context.Input.CruSession.Set(name, value) } diff --git a/controller.go b/controller.go index a8b9e8d3..ab261a56 100644 --- a/controller.go +++ b/controller.go @@ -19,7 +19,6 @@ import ( "errors" "html/template" "io" - "io/ioutil" "mime/multipart" "net/http" "net/url" @@ -34,18 +33,19 @@ import ( //commonly used mime-types const ( - applicationJson = "application/json" - applicationXml = "application/xml" - textXml = "text/xml" + applicationJSON = "application/json" + applicationXML = "application/xml" + textXML = "text/xml" ) var ( - // custom error when user stop request handler manually. - USERSTOPRUN = errors.New("User stop run") - GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments + // ErrAbort custom error when user stop request handler manually. + ErrAbort = errors.New("User stop run") + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = make(map[string][]ControllerComments) ) -// store the comment for the controller method +// ControllerComments store the comment for the controller method type ControllerComments struct { Method string Router string @@ -56,22 +56,31 @@ type ControllerComments struct { // Controller defines some basic http request handler operations, such as // http context, template and view, session and xsrf. type Controller struct { - Ctx *context.Context - Data map[interface{}]interface{} + // context data + Ctx *context.Context + Data map[interface{}]interface{} + + // route controller info controllerName string actionName string - TplNames string + methodMapping map[string]func() //method:routertree + gotofunc string + AppController interface{} + + // template data + TplName string Layout string LayoutSections map[string]string // the key is the section name and the value is the template name TplExt string - _xsrf_token string - gotofunc string - CruSession session.SessionStore - XSRFExpire int - AppController interface{} EnableRender bool - EnableXSRF bool - methodMapping map[string]func() //method:routertree + + // xsrf data + _xsrfToken string + XSRFExpire int + EnableXSRF bool + + // session + CruSession session.Store } // ControllerInterface is an interface to uniform all controller handler. @@ -87,8 +96,8 @@ type ControllerInterface interface { Options() Finish() Render() error - XsrfToken() string - CheckXsrfCookie() bool + XSRFToken() string + CheckXSRFCookie() bool HandlerFunc(fn string) bool URLMapping() } @@ -96,7 +105,7 @@ type ControllerInterface interface { // Init generates default values of controller operations. func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { c.Layout = "" - c.TplNames = "" + c.TplName = "" c.controllerName = controllerName c.actionName = actionName c.Ctx = ctx @@ -104,19 +113,15 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin c.AppController = app c.EnableRender = true c.EnableXSRF = true - c.Data = ctx.Input.Data + c.Data = ctx.Input.Data() c.methodMapping = make(map[string]func()) } // Prepare runs after Init before request function execution. -func (c *Controller) Prepare() { - -} +func (c *Controller) Prepare() {} // Finish runs after request function execution. -func (c *Controller) Finish() { - -} +func (c *Controller) Finish() {} // Get adds a request function to handle GET request. func (c *Controller) Get() { @@ -153,20 +158,19 @@ func (c *Controller) Options() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) } -// call function fn +// HandlerFunc call function with the name func (c *Controller) HandlerFunc(fnname string) bool { if v, ok := c.methodMapping[fnname]; ok { v() return true - } else { - return false } + return false } // URLMapping register the internal Controller router. -func (c *Controller) URLMapping() { -} +func (c *Controller) URLMapping() {} +// Mapping the method to function func (c *Controller) Mapping(method string, fn func()) { c.methodMapping[method] = fn } @@ -177,13 +181,11 @@ func (c *Controller) Render() error { return nil } rb, err := c.RenderBytes() - if err != nil { return err - } else { - c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") - c.Ctx.Output.Body(rb) } + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + c.Ctx.Output.Body(rb) return nil } @@ -196,24 +198,33 @@ func (c *Controller) RenderString() (string, error) { // RenderBytes returns the bytes of rendered template string. Do not send out response. func (c *Controller) RenderBytes() ([]byte, error) { //if the controller has set layout, then first get the tplname's content set the content to the layout + var buf bytes.Buffer if c.Layout != "" { - if c.TplNames == "" { - c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt } - if RunMode == "dev" { - BuildTemplate(ViewsPath) + + if BConfig.RunMode == DEV { + buildFiles := []string{c.TplName} + if c.LayoutSections != nil { + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue + } + buildFiles = append(buildFiles, sectionTpl) + } + } + BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) } - newbytes := bytes.NewBufferString("") - if _, ok := BeeTemplates[c.TplNames]; !ok { - panic("can't find templatefile in the path:" + c.TplNames) + if _, ok := BeeTemplates[c.TplName]; !ok { + panic("can't find templatefile in the path:" + c.TplName) } - err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data) + err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data) if err != nil { Trace("template Execute err:", err) return nil, err } - tplcontent, _ := ioutil.ReadAll(newbytes) - c.Data["LayoutContent"] = template.HTML(string(tplcontent)) + c.Data["LayoutContent"] = template.HTML(buf.String()) if c.LayoutSections != nil { for sectionName, sectionTpl := range c.LayoutSections { @@ -222,44 +233,41 @@ func (c *Controller) RenderBytes() ([]byte, error) { continue } - sectionBytes := bytes.NewBufferString("") - err = BeeTemplates[sectionTpl].ExecuteTemplate(sectionBytes, sectionTpl, c.Data) + buf.Reset() + err = BeeTemplates[sectionTpl].ExecuteTemplate(&buf, sectionTpl, c.Data) if err != nil { Trace("template Execute err:", err) return nil, err } - sectionContent, _ := ioutil.ReadAll(sectionBytes) - c.Data[sectionName] = template.HTML(string(sectionContent)) + c.Data[sectionName] = template.HTML(buf.String()) } } - ibytes := bytes.NewBufferString("") - err = BeeTemplates[c.Layout].ExecuteTemplate(ibytes, c.Layout, c.Data) + buf.Reset() + err = BeeTemplates[c.Layout].ExecuteTemplate(&buf, c.Layout, c.Data) if err != nil { Trace("template Execute err:", err) return nil, err } - icontent, _ := ioutil.ReadAll(ibytes) - return icontent, nil - } else { - if c.TplNames == "" { - c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt - } - if RunMode == "dev" { - BuildTemplate(ViewsPath) - } - ibytes := bytes.NewBufferString("") - if _, ok := BeeTemplates[c.TplNames]; !ok { - panic("can't find templatefile in the path:" + c.TplNames) - } - err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data) - if err != nil { - Trace("template Execute err:", err) - return nil, err - } - icontent, _ := ioutil.ReadAll(ibytes) - return icontent, nil + return buf.Bytes(), nil } + + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + } + if BConfig.RunMode == DEV { + BuildTemplate(BConfig.WebConfig.ViewsPath, c.TplName) + } + if _, ok := BeeTemplates[c.TplName]; !ok { + panic("can't find templatefile in the path:" + c.TplName) + } + buf.Reset() + err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data) + if err != nil { + Trace("template Execute err:", err) + return nil, err + } + return buf.Bytes(), nil } // Redirect sends the redirection response to url with status code. @@ -267,7 +275,7 @@ func (c *Controller) Redirect(url string, code int) { c.Ctx.Redirect(code, url) } -// Aborts stops controller handler and show the error data if code is defined in ErrorMap or code string. +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. func (c *Controller) Abort(code string) { status, err := strconv.Atoi(code) if err != nil { @@ -285,74 +293,69 @@ func (c *Controller) CustomAbort(status int, body string) { } // last panic user string c.Ctx.ResponseWriter.Write([]byte(body)) - panic(USERSTOPRUN) + panic(ErrAbort) } // StopRun makes panic of USERSTOPRUN error and go to recover function if defined. func (c *Controller) StopRun() { - panic(USERSTOPRUN) + panic(ErrAbort) } -// UrlFor does another controller handler in this request function. +// URLFor does another controller handler in this request function. // it goes to this controller method if endpoint is not clear. -func (c *Controller) UrlFor(endpoint string, values ...interface{}) string { - if len(endpoint) <= 0 { +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + if len(endpoint) == 0 { return "" } if endpoint[0] == '.' { - return UrlFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) - } else { - return UrlFor(endpoint, values...) + return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) } + return URLFor(endpoint, values...) } -// ServeJson sends a json response with encoding charset. -func (c *Controller) ServeJson(encoding ...bool) { - var hasIndent bool - var hasencoding bool - if RunMode == "prod" { +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + var ( + hasIndent = true + hasEncoding = false + ) + if BConfig.RunMode == PROD { hasIndent = false - } else { - hasIndent = true } if len(encoding) > 0 && encoding[0] == true { - hasencoding = true + hasEncoding = true } - c.Ctx.Output.Json(c.Data["json"], hasIndent, hasencoding) + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) } -// ServeJsonp sends a jsonp response. -func (c *Controller) ServeJsonp() { - var hasIndent bool - if RunMode == "prod" { +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + hasIndent := true + if BConfig.RunMode == PROD { hasIndent = false - } else { - hasIndent = true } - c.Ctx.Output.Jsonp(c.Data["jsonp"], hasIndent) + c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) } -// ServeXml sends xml response. -func (c *Controller) ServeXml() { - var hasIndent bool - if RunMode == "prod" { +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + hasIndent := true + if BConfig.RunMode == PROD { hasIndent = false - } else { - hasIndent = true } - c.Ctx.Output.Xml(c.Data["xml"], hasIndent) + c.Ctx.Output.XML(c.Data["xml"], hasIndent) } // ServeFormatted serve Xml OR Json, depending on the value of the Accept header func (c *Controller) ServeFormatted() { accept := c.Ctx.Input.Header("Accept") switch accept { - case applicationJson: - c.ServeJson() - case applicationXml, textXml: - c.ServeXml() + case applicationJSON: + c.ServeJSON() + case applicationXML, textXML: + c.ServeXML() default: - c.ServeJson() + c.ServeJSON() } } @@ -371,16 +374,13 @@ func (c *Controller) ParseForm(obj interface{}) error { // GetString returns the input value by key string or the default value while it's present and input is blank func (c *Controller) GetString(key string, def ...string) string { - var defv string - if len(def) > 0 { - defv = def[0] - } - if v := c.Ctx.Input.Query(key); v != "" { return v - } else { - return defv } + if len(def) > 0 { + return def[0] + } + return "" } // GetStrings returns the input string slice by key string or the default value while it's present and input is blank @@ -391,106 +391,81 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string { defv = def[0] } - f := c.Input() - if f == nil { + if f := c.Input(); f == nil { return defv + } else { + if vs := f[key]; len(vs) > 0 { + return vs + } } - vs := f[key] - if len(vs) > 0 { - return vs - } else { - return defv - } + return defv } // GetInt returns input as an int or the default value while it's present and input is blank func (c *Controller) GetInt(key string, def ...int) (int, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - return strconv.Atoi(strv) - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - return strconv.Atoi(strv) } + return strconv.Atoi(strv) } // GetInt8 return input as an int8 or the default value while it's present and input is blank func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - i64, err := strconv.ParseInt(strv, 10, 8) - i8 := int8(i64) - return i8, err - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - i64, err := strconv.ParseInt(strv, 10, 8) - i8 := int8(i64) - return i8, err } + i64, err := strconv.ParseInt(strv, 10, 8) + return int8(i64), err } // GetInt16 returns input as an int16 or the default value while it's present and input is blank func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - i64, err := strconv.ParseInt(strv, 10, 16) - i16 := int16(i64) - return i16, err - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - i64, err := strconv.ParseInt(strv, 10, 16) - i16 := int16(i64) - return i16, err } + i64, err := strconv.ParseInt(strv, 10, 16) + return int16(i64), err } // GetInt32 returns input as an int32 or the default value while it's present and input is blank func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32) - i32 := int32(i64) - return i32, err - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32) - i32 := int32(i64) - return i32, err } + i64, err := strconv.ParseInt(strv, 10, 32) + return int32(i64), err } // GetInt64 returns input value as int64 or the default value while it's present and input is blank. func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - return strconv.ParseInt(strv, 10, 64) - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - return strconv.ParseInt(strv, 10, 64) } + return strconv.ParseInt(strv, 10, 64) } // GetBool returns input value as bool or the default value while it's present and input is blank. func (c *Controller) GetBool(key string, def ...bool) (bool, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - return strconv.ParseBool(strv) - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - return strconv.ParseBool(strv) } + return strconv.ParseBool(strv) } // GetFloat returns input value as float64 or the default value while it's present and input is blank. func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { - if strv := c.Ctx.Input.Query(key); strv != "" { - return strconv.ParseFloat(strv, 64) - } else if len(def) > 0 { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { return def[0], nil - } else { - return strconv.ParseFloat(strv, 64) } + return strconv.ParseFloat(strv, 64) } // GetFile returns the file data in file upload field named as key. @@ -527,8 +502,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, // } // } func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { - files, ok := c.Ctx.Request.MultipartForm.File[key] - if ok { + if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { return files, nil } return nil, http.ErrMissingFile @@ -552,7 +526,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error { } // StartSession starts session and load old session data info this controller. -func (c *Controller) StartSession() session.SessionStore { +func (c *Controller) StartSession() session.Store { if c.CruSession == nil { c.CruSession = c.Ctx.Input.CruSession } @@ -575,7 +549,7 @@ func (c *Controller) GetSession(name interface{}) interface{} { return c.CruSession.Get(name) } -// SetSession removes value from session. +// DelSession removes value from session. func (c *Controller) DelSession(name interface{}) { if c.CruSession == nil { c.StartSession() @@ -589,7 +563,7 @@ func (c *Controller) SessionRegenerateID() { if c.CruSession != nil { c.CruSession.SessionRelease(c.Ctx.ResponseWriter) } - c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) + c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession } @@ -614,37 +588,35 @@ func (c *Controller) SetSecureCookie(Secret, name, value string, others ...inter c.Ctx.SetSecureCookie(Secret, name, value, others...) } -// XsrfToken creates a xsrf token string and returns. -func (c *Controller) XsrfToken() string { - if c._xsrf_token == "" { - var expire int64 +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + if c._xsrfToken == "" { + expire := int64(BConfig.WebConfig.XSRFExpire) if c.XSRFExpire > 0 { expire = int64(c.XSRFExpire) - } else { - expire = int64(XSRFExpire) } - c._xsrf_token = c.Ctx.XsrfToken(XSRFKEY, expire) + c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) } - return c._xsrf_token + return c._xsrfToken } -// CheckXsrfCookie checks xsrf token in this request is valid or not. +// CheckXSRFCookie checks xsrf token in this request is valid or not. // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // or in form field value named as "_xsrf". -func (c *Controller) CheckXsrfCookie() bool { +func (c *Controller) CheckXSRFCookie() bool { if !c.EnableXSRF { return true } - return c.Ctx.CheckXsrfCookie() + return c.Ctx.CheckXSRFCookie() } -// XsrfFormHtml writes an input field contains xsrf token value. -func (c *Controller) XsrfFormHtml() string { - return "" +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return `` } // GetControllerAndAction gets the executing controller name and action name. -func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { +func (c *Controller) GetControllerAndAction() (string, string) { return c.controllerName, c.actionName } diff --git a/controller_test.go b/controller_test.go index 15938cdc..51d3a5b7 100644 --- a/controller_test.go +++ b/controller_test.go @@ -15,61 +15,63 @@ package beego import ( - "fmt" + "testing" + "github.com/astaxie/beego/context" ) -func ExampleGetInt() { - - i := &context.BeegoInput{Params: map[string]string{"age": "40"}} +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt("age") - fmt.Printf("%T", val) - //Output: int + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } } -func ExampleGetInt8() { - - i := &context.BeegoInput{Params: map[string]string{"age": "40"}} +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt8("age") - fmt.Printf("%T", val) + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } //Output: int8 } -func ExampleGetInt16() { - - i := &context.BeegoInput{Params: map[string]string{"age": "40"}} +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt16("age") - fmt.Printf("%T", val) - //Output: int16 + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } } -func ExampleGetInt32() { - - i := &context.BeegoInput{Params: map[string]string{"age": "40"}} +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt32("age") - fmt.Printf("%T", val) - //Output: int32 + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } } -func ExampleGetInt64() { - - i := &context.BeegoInput{Params: map[string]string{"age": "40"}} +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt64("age") - fmt.Printf("%T", val) - //Output: int64 + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } } diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..4be305b3 --- /dev/null +++ b/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego" + + func main() { + beego.Run() + } + +more infomation: http://beego.me +*/ +package beego diff --git a/docs.go b/docs.go index aaad205e..72532876 100644 --- a/docs.go +++ b/docs.go @@ -15,37 +15,24 @@ package beego import ( - "encoding/json" - "github.com/astaxie/beego/context" ) -var GlobalDocApi map[string]interface{} - -func init() { - if EnableDocs { - GlobalDocApi = make(map[string]interface{}) - } -} +// GlobalDocAPI store the swagger api documents +var GlobalDocAPI = make(map[string]interface{}) func serverDocs(ctx *context.Context) { var obj interface{} if splat := ctx.Input.Param(":splat"); splat == "" { - obj = GlobalDocApi["Root"] + obj = GlobalDocAPI["Root"] } else { - if v, ok := GlobalDocApi[splat]; ok { + if v, ok := GlobalDocAPI[splat]; ok { obj = v } } if obj != nil { - bt, err := json.Marshal(obj) - if err != nil { - ctx.Output.SetStatus(504) - return - } - ctx.Output.Header("Content-Type", "application/json;charset=UTF-8") ctx.Output.Header("Access-Control-Allow-Origin", "*") - ctx.Output.Body(bt) + ctx.Output.JSON(obj, false, false) return } ctx.Output.SetStatus(404) diff --git a/error.go b/error.go index 99a1fcf3..af57b7c7 100644 --- a/error.go +++ b/error.go @@ -82,16 +82,17 @@ var tpl = ` ` // render default application error page with error and stack string. -func showErr(err interface{}, ctx *context.Context, Stack string) { +func showErr(err interface{}, ctx *context.Context, stack string) { t, _ := template.New("beegoerrortemp").Parse(tpl) - data := make(map[string]string) - data["AppError"] = AppName + ":" + fmt.Sprint(err) - data["RequestMethod"] = ctx.Input.Method() - data["RequestURL"] = ctx.Input.Uri() - data["RemoteAddr"] = ctx.Input.IP() - data["Stack"] = Stack - data["BeegoVersion"] = VERSION - data["GoVersion"] = runtime.Version() + data := map[string]string{ + "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), + "RequestMethod": ctx.Input.Method(), + "RequestURL": ctx.Input.URI(), + "RemoteAddr": ctx.Input.IP(), + "Stack": stack, + "BeegoVersion": VERSION, + "GoVersion": runtime.Version(), + } ctx.ResponseWriter.WriteHeader(500) t.Execute(ctx.ResponseWriter, data) } @@ -204,47 +205,48 @@ type errorInfo struct { } // map of http handlers for each error string. -var ErrorMaps map[string]*errorInfo - -func init() { - ErrorMaps = make(map[string]*errorInfo) -} +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = make(map[string]*errorInfo, 10) // show 401 unauthorized error. func unauthorized(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Unauthorized" + data := map[string]interface{}{ + "Title": http.StatusText(401), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested can't be authorized." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 402 Payment Required func paymentRequired(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Payment Required" + data := map[string]interface{}{ + "Title": http.StatusText(402), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested Payment Required." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 403 forbidden error. func forbidden(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Forbidden" + data := map[string]interface{}{ + "Title": http.StatusText(403), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is forbidden." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 404 notfound error. func notFound(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Page Not Found" + data := map[string]interface{}{ + "Title": http.StatusText(404), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested has flown the coop." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 405 Method Not Allowed func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Method Not Allowed" + data := map[string]interface{}{ + "Title": http.StatusText(405), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The method you have requested Not Allowed." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 500 internal server error. func internalServerError(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Internal Server Error" + data := map[string]interface{}{ + "Title": http.StatusText(500), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is down right now." + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 501 Not Implemented. func notImplemented(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Not Implemented" + data := map[string]interface{}{ + "Title": http.StatusText(504), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is Not Implemented." + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 502 Bad Gateway. func badGateway(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Bad Gateway" + data := map[string]interface{}{ + "Title": http.StatusText(502), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is down right now." + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 503 service unavailable error. func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Service Unavailable" + data := map[string]interface{}{ + "Title": http.StatusText(503), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is unavailable." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } // show 504 Gateway Timeout. func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := make(map[string]interface{}) - data["Title"] = "Gateway Timeout" + data := map[string]interface{}{ + "Title": http.StatusText(504), + "BeegoVersion": VERSION, + } data["Content"] = template.HTML("
The page you have requested is unavailable." + "
Perhaps you are here because:" + "

") - data["BeegoVersion"] = VERSION t.Execute(rw, data) } -// register default error http handlers, 404,401,403,500 and 503. -func registerDefaultErrorHandler() { - if _, ok := ErrorMaps["401"]; !ok { - Errorhandler("401", unauthorized) - } - - if _, ok := ErrorMaps["402"]; !ok { - Errorhandler("402", paymentRequired) - } - - if _, ok := ErrorMaps["403"]; !ok { - Errorhandler("403", forbidden) - } - - if _, ok := ErrorMaps["404"]; !ok { - Errorhandler("404", notFound) - } - - if _, ok := ErrorMaps["405"]; !ok { - Errorhandler("405", methodNotAllowed) - } - - if _, ok := ErrorMaps["500"]; !ok { - Errorhandler("500", internalServerError) - } - if _, ok := ErrorMaps["501"]; !ok { - Errorhandler("501", notImplemented) - } - if _, ok := ErrorMaps["502"]; !ok { - Errorhandler("502", badGateway) - } - - if _, ok := ErrorMaps["503"]; !ok { - Errorhandler("503", serviceUnavailable) - } - - if _, ok := ErrorMaps["504"]; !ok { - Errorhandler("504", gatewayTimeout) - } -} - // ErrorHandler registers http.HandlerFunc to each http err code string. // usage: // beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("500",InternalServerError) -func Errorhandler(code string, h http.HandlerFunc) *App { - errinfo := &errorInfo{} - errinfo.errorType = errorTypeHandler - errinfo.handler = h - errinfo.method = code - ErrorMaps[code] = errinfo +func ErrorHandler(code string, h http.HandlerFunc) *App { + ErrorMaps[code] = &errorInfo{ + errorType: errorTypeHandler, + handler: h, + method: code, + } return BeeApp } // ErrorController registers ControllerInterface to each http err code string. // usage: -// beego.ErrorHandler(&controllers.ErrorController{}) +// beego.ErrorController(&controllers.ErrorController{}) func ErrorController(c ControllerInterface) *App { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() ct := reflect.Indirect(reflectVal).Type() for i := 0; i < rt.NumMethod(); i++ { - if !utils.InSlice(rt.Method(i).Name, exceptMethod) && strings.HasPrefix(rt.Method(i).Name, "Error") { - errinfo := &errorInfo{} - errinfo.errorType = errorTypeController - errinfo.controllerType = ct - errinfo.method = rt.Method(i).Name - errname := strings.TrimPrefix(rt.Method(i).Name, "Error") - ErrorMaps[errname] = errinfo + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { + errName := strings.TrimPrefix(methodName, "Error") + ErrorMaps[errName] = &errorInfo{ + errorType: errorTypeController, + controllerType: ct, + method: methodName, + } } } return BeeApp } // show error string as simple text message. -// if error string is empty, show 500 error as default. -func exception(errcode string, ctx *context.Context) { - code, err := strconv.Atoi(errcode) - if err != nil { - code = 503 +// if error string is empty, show 503 or 500 error as default. +func exception(errCode string, ctx *context.Context) { + atoi := func(code string) int { + v, err := strconv.Atoi(code) + if err == nil { + return v + } + return 503 } - if h, ok := ErrorMaps[errcode]; ok { - executeError(h, ctx, code) - return - } else if h, ok := ErrorMaps["503"]; ok { - executeError(h, ctx, code) - return - } else { - ctx.ResponseWriter.WriteHeader(code) - ctx.WriteString(errcode) + + for _, ec := range []string{errCode, "503", "500"} { + if h, ok := ErrorMaps[ec]; ok { + executeError(h, ctx, atoi(ec)) + return + } } + //if 50x error has been removed from errorMap + ctx.ResponseWriter.WriteHeader(atoi(errCode)) + ctx.WriteString(errCode) } func executeError(err *errorInfo, ctx *context.Context, code int) { if err.errorType == errorTypeHandler { - ctx.ResponseWriter.WriteHeader(code) err.handler(ctx.ResponseWriter, ctx.Request) return } @@ -473,12 +443,11 @@ func executeError(err *errorInfo, ctx *context.Context, code int) { execController.URLMapping() - in := make([]reflect.Value, 0) method := vc.MethodByName(err.method) - method.Call(in) + method.Call([]reflect.Value{}) //render template - if AutoRender { + if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { panic(err) } diff --git a/example/beeapi/conf/app.conf b/example/beeapi/conf/app.conf deleted file mode 100644 index fd1681a2..00000000 --- a/example/beeapi/conf/app.conf +++ /dev/null @@ -1,5 +0,0 @@ -appname = beeapi -httpport = 8080 -runmode = dev -autorender = false -copyrequestbody = true diff --git a/example/beeapi/controllers/default.go b/example/beeapi/controllers/default.go deleted file mode 100644 index a2184075..00000000 --- a/example/beeapi/controllers/default.go +++ /dev/null @@ -1,63 +0,0 @@ -// Beego (http://beego.me/) -// @description beego is an open-source, high-performance web framework for the Go programming language. -// @link http://github.com/astaxie/beego for the canonical source repository -// @license http://github.com/astaxie/beego/blob/master/LICENSE -// @authors astaxie - -package controllers - -import ( - "encoding/json" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/example/beeapi/models" -) - -type ObjectController struct { - beego.Controller -} - -func (o *ObjectController) Post() { - var ob models.Object - json.Unmarshal(o.Ctx.Input.RequestBody, &ob) - objectid := models.AddOne(ob) - o.Data["json"] = map[string]string{"ObjectId": objectid} - o.ServeJson() -} - -func (o *ObjectController) Get() { - objectId := o.Ctx.Input.Params[":objectId"] - if objectId != "" { - ob, err := models.GetOne(objectId) - if err != nil { - o.Data["json"] = err - } else { - o.Data["json"] = ob - } - } else { - obs := models.GetAll() - o.Data["json"] = obs - } - o.ServeJson() -} - -func (o *ObjectController) Put() { - objectId := o.Ctx.Input.Params[":objectId"] - var ob models.Object - json.Unmarshal(o.Ctx.Input.RequestBody, &ob) - - err := models.Update(objectId, ob.Score) - if err != nil { - o.Data["json"] = err - } else { - o.Data["json"] = "update success!" - } - o.ServeJson() -} - -func (o *ObjectController) Delete() { - objectId := o.Ctx.Input.Params[":objectId"] - models.Delete(objectId) - o.Data["json"] = "delete success!" - o.ServeJson() -} diff --git a/example/beeapi/main.go b/example/beeapi/main.go deleted file mode 100644 index c1250e03..00000000 --- a/example/beeapi/main.go +++ /dev/null @@ -1,30 +0,0 @@ -// Beego (http://beego.me/) - -// @description beego is an open-source, high-performance web framework for the Go programming language. - -// @link http://github.com/astaxie/beego for the canonical source repository - -// @license http://github.com/astaxie/beego/blob/master/LICENSE - -// @authors astaxie - -package main - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/example/beeapi/controllers" -) - -// Objects - -// URL HTTP Verb Functionality -// /object POST Creating Objects -// /object/ GET Retrieving Objects -// /object/ PUT Updating Objects -// /object GET Queries -// /object/ DELETE Deleting Objects - -func main() { - beego.RESTRouter("/object", &controllers.ObjectController{}) - beego.Run() -} diff --git a/example/beeapi/models/object.go b/example/beeapi/models/object.go deleted file mode 100644 index 46109c50..00000000 --- a/example/beeapi/models/object.go +++ /dev/null @@ -1,58 +0,0 @@ -// Beego (http://beego.me/) -// @description beego is an open-source, high-performance web framework for the Go programming language. -// @link http://github.com/astaxie/beego for the canonical source repository -// @license http://github.com/astaxie/beego/blob/master/LICENSE -// @authors astaxie - -package models - -import ( - "errors" - "strconv" - "time" -) - -var ( - Objects map[string]*Object -) - -type Object struct { - ObjectId string - Score int64 - PlayerName string -} - -func init() { - Objects = make(map[string]*Object) - Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"} - Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"} -} - -func AddOne(object Object) (ObjectId string) { - object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10) - Objects[object.ObjectId] = &object - return object.ObjectId -} - -func GetOne(ObjectId string) (object *Object, err error) { - if v, ok := Objects[ObjectId]; ok { - return v, nil - } - return nil, errors.New("ObjectId Not Exist") -} - -func GetAll() map[string]*Object { - return Objects -} - -func Update(ObjectId string, Score int64) (err error) { - if v, ok := Objects[ObjectId]; ok { - v.Score = Score - return nil - } - return errors.New("ObjectId Not Exist") -} - -func Delete(ObjectId string) { - delete(Objects, ObjectId) -} diff --git a/example/chat/conf/app.conf b/example/chat/conf/app.conf deleted file mode 100644 index 9d5a823f..00000000 --- a/example/chat/conf/app.conf +++ /dev/null @@ -1,3 +0,0 @@ -appname = chat -httpport = 8080 -runmode = dev diff --git a/example/chat/controllers/default.go b/example/chat/controllers/default.go deleted file mode 100644 index 95ef8a9b..00000000 --- a/example/chat/controllers/default.go +++ /dev/null @@ -1,20 +0,0 @@ -// Beego (http://beego.me/) -// @description beego is an open-source, high-performance web framework for the Go programming language. -// @link http://github.com/astaxie/beego for the canonical source repository -// @license http://github.com/astaxie/beego/blob/master/LICENSE -// @authors Unknwon - -package controllers - -import ( - "github.com/astaxie/beego" -) - -type MainController struct { - beego.Controller -} - -func (m *MainController) Get() { - m.Data["host"] = m.Ctx.Request.Host - m.TplNames = "index.tpl" -} diff --git a/example/chat/controllers/ws.go b/example/chat/controllers/ws.go deleted file mode 100644 index fc3917b3..00000000 --- a/example/chat/controllers/ws.go +++ /dev/null @@ -1,181 +0,0 @@ -// Beego (http://beego.me/) -// @description beego is an open-source, high-performance web framework for the Go programming language. -// @link http://github.com/astaxie/beego for the canonical source repository -// @license http://github.com/astaxie/beego/blob/master/LICENSE -// @authors Unknwon - -package controllers - -import ( - "io/ioutil" - "math/rand" - "net/http" - "time" - - "github.com/astaxie/beego" - "github.com/gorilla/websocket" -) - -const ( - // Time allowed to write a message to the client. - writeWait = 10 * time.Second - - // Time allowed to read the next message from the client. - readWait = 60 * time.Second - - // Send pings to client with this period. Must be less than readWait. - pingPeriod = (readWait * 9) / 10 - - // Maximum message size allowed from client. - maxMessageSize = 512 -) - -func init() { - rand.Seed(time.Now().UTC().UnixNano()) - go h.run() -} - -// connection is an middleman between the websocket connection and the hub. -type connection struct { - username string - - // The websocket connection. - ws *websocket.Conn - - // Buffered channel of outbound messages. - send chan []byte -} - -// readPump pumps messages from the websocket connection to the hub. -func (c *connection) readPump() { - defer func() { - h.unregister <- c - c.ws.Close() - }() - c.ws.SetReadLimit(maxMessageSize) - c.ws.SetReadDeadline(time.Now().Add(readWait)) - for { - op, r, err := c.ws.NextReader() - if err != nil { - break - } - switch op { - case websocket.PongMessage: - c.ws.SetReadDeadline(time.Now().Add(readWait)) - case websocket.TextMessage: - message, err := ioutil.ReadAll(r) - if err != nil { - break - } - h.broadcast <- []byte(c.username + "_" + time.Now().Format("15:04:05") + ":" + string(message)) - } - } -} - -// write writes a message with the given opCode and payload. -func (c *connection) write(opCode int, payload []byte) error { - c.ws.SetWriteDeadline(time.Now().Add(writeWait)) - return c.ws.WriteMessage(opCode, payload) -} - -// writePump pumps messages from the hub to the websocket connection. -func (c *connection) writePump() { - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - c.ws.Close() - }() - for { - select { - case message, ok := <-c.send: - if !ok { - c.write(websocket.CloseMessage, []byte{}) - return - } - if err := c.write(websocket.TextMessage, message); err != nil { - return - } - case <-ticker.C: - if err := c.write(websocket.PingMessage, []byte{}); err != nil { - return - } - } - } -} - -type hub struct { - // Registered connections. - connections map[*connection]bool - - // Inbound messages from the connections. - broadcast chan []byte - - // Register requests from the connections. - register chan *connection - - // Unregister requests from connections. - unregister chan *connection -} - -var h = &hub{ - broadcast: make(chan []byte, maxMessageSize), - register: make(chan *connection, 1), - unregister: make(chan *connection, 1), - connections: make(map[*connection]bool), -} - -func (h *hub) run() { - for { - select { - case c := <-h.register: - h.connections[c] = true - case c := <-h.unregister: - delete(h.connections, c) - close(c.send) - case m := <-h.broadcast: - for c := range h.connections { - select { - case c.send <- m: - default: - close(c.send) - delete(h.connections, c) - } - } - } - } -} - -type WSController struct { - beego.Controller -} - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -func (w *WSController) Get() { - ws, err := upgrader.Upgrade(w.Ctx.ResponseWriter, w.Ctx.Request, nil) - if _, ok := err.(websocket.HandshakeError); ok { - http.Error(w.Ctx.ResponseWriter, "Not a websocket handshake", 400) - return - } else if err != nil { - return - } - c := &connection{send: make(chan []byte, 256), ws: ws, username: randomString(10)} - h.register <- c - go c.writePump() - c.readPump() -} - -func randomString(l int) string { - bytes := make([]byte, l) - for i := 0; i < l; i++ { - bytes[i] = byte(randInt(65, 90)) - } - return string(bytes) -} - -func randInt(min int, max int) int { - return min + rand.Intn(max-min) -} diff --git a/example/chat/main.go b/example/chat/main.go deleted file mode 100644 index be055b25..00000000 --- a/example/chat/main.go +++ /dev/null @@ -1,17 +0,0 @@ -// Beego (http://beego.me/) -// @description beego is an open-source, high-performance web framework for the Go programming language. -// @link http://github.com/astaxie/beego for the canonical source repository -// @license http://github.com/astaxie/beego/blob/master/LICENSE -// @authors Unknwon -package main - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/example/chat/controllers" -) - -func main() { - beego.Router("/", &controllers.MainController{}) - beego.Router("/ws", &controllers.WSController{}) - beego.Run() -} diff --git a/example/chat/views/index.tpl b/example/chat/views/index.tpl deleted file mode 100644 index 3a9d8838..00000000 --- a/example/chat/views/index.tpl +++ /dev/null @@ -1,92 +0,0 @@ - - - -Chat Example - - - - - -
-
- - -
- - \ No newline at end of file diff --git a/filter.go b/filter.go index f673ab66..863223f7 100644 --- a/filter.go +++ b/filter.go @@ -32,14 +32,12 @@ type FilterRouter struct { // ValidRouter checks if the current request is matched by this filter. // If the request is matched, the values of the URL parameters defined // by the filter pattern are also returned. -func (f *FilterRouter) ValidRouter(url string) (bool, map[string]string) { - isok, params := f.tree.Match(url) - if isok == nil { - return false, nil - } - if isok, ok := isok.(bool); ok { - return isok, params - } else { - return false, nil +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } } + return false } diff --git a/filter_test.go b/filter_test.go index ff6f750b..d9928d8d 100644 --- a/filter_test.go +++ b/filter_test.go @@ -20,10 +20,16 @@ import ( "testing" "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" ) +func init() { + BeeLogger = logs.NewLogger(10000) + BeeLogger.SetLogger("console", "") +} + var FilterUser = func(ctx *context.Context) { - ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"])) + ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) } func TestFilter(t *testing.T) { diff --git a/flash.go b/flash.go index 5ccf339a..a6485a17 100644 --- a/flash.go +++ b/flash.go @@ -83,27 +83,27 @@ func (fd *FlashData) Store(c *Controller) { c.Data["flash"] = fd.Data var flashValue string for key, value := range fd.Data { - flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00" + flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" } - c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/") + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") } // ReadFromRequest parsed flash data from encoded values in cookie. func ReadFromRequest(c *Controller) *FlashData { flash := NewFlash() - if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil { + if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { v, _ := url.QueryUnescape(cookie.Value) vals := strings.Split(v, "\x00") for _, v := range vals { if len(v) > 0 { - kv := strings.Split(v, "\x23"+FlashSeperator+"\x23") + kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") if len(kv) == 2 { flash.Data[kv[0]] = kv[1] } } } //read one time then delete it - c.Ctx.SetCookie(FlashName, "", -1, "/") + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") } c.Data["flash"] = flash.Data return flash diff --git a/flash_test.go b/flash_test.go index b655f552..640d54de 100644 --- a/flash_test.go +++ b/flash_test.go @@ -30,7 +30,7 @@ func (t *TestFlashController) TestWriteFlash() { flash.Notice("TestFlashString") flash.Store(&t.Controller) // we choose to serve json because we don't want to load a template html file - t.ServeJson(true) + t.ServeJSON(true) } func TestFlashHeader(t *testing.T) { diff --git a/grace/conn.go b/grace/conn.go index 2cf3a93d..6807e1ac 100644 --- a/grace/conn.go +++ b/grace/conn.go @@ -1,13 +1,28 @@ package grace -import "net" +import ( + "errors" + "net" +) type graceConn struct { net.Conn - server *graceServer + server *Server } -func (c graceConn) Close() error { +func (c graceConn) Close() (err error) { + defer func() { + if r := recover(); r != nil { + switch x := r.(type) { + case string: + err = errors.New(x) + case error: + err = x + default: + err = errors.New("Unknown panic") + } + } + }() c.server.wg.Done() return c.Conn.Close() } diff --git a/grace/grace.go b/grace/grace.go index e5577267..af530d50 100644 --- a/grace/grace.go +++ b/grace/grace.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package grace use to hot reload // Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ // // Usage: @@ -32,7 +33,7 @@ // mux := http.NewServeMux() // mux.HandleFunc("/hello", handler) // -// err := grace.ListenAndServe("localhost:8080", mux1) +// err := grace.ListenAndServe("localhost:8080", mux) // if err != nil { // log.Println(err) // } @@ -52,46 +53,53 @@ import ( ) const ( - PRE_SIGNAL = iota - POST_SIGNAL - - STATE_INIT - STATE_RUNNING - STATE_SHUTTING_DOWN - STATE_TERMINATE + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate ) var ( regLock *sync.Mutex - runningServers map[string]*graceServer + runningServers map[string]*Server runningServersOrder []string socketPtrOffsetMap map[string]uint runningServersForked bool - DefaultReadTimeOut time.Duration - DefaultWriteTimeOut time.Duration + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit DefaultMaxHeaderBytes int - DefaultTimeout time.Duration + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = 60 * time.Second isChild bool socketOrder string + once sync.Once ) -func init() { +func onceInit() { regLock = &sync.Mutex{} flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") - runningServers = make(map[string]*graceServer) + runningServers = make(map[string]*Server) runningServersOrder = []string{} socketPtrOffsetMap = make(map[string]uint) - - DefaultMaxHeaderBytes = 0 - - DefaultTimeout = 60 * time.Second } // NewServer returns a new graceServer. -func NewServer(addr string, handler http.Handler) (srv *graceServer) { +func NewServer(addr string, handler http.Handler) (srv *Server) { + once.Do(onceInit) regLock.Lock() defer regLock.Unlock() if !flag.Parsed() { @@ -105,23 +113,23 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) { socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) } - srv = &graceServer{ + srv = &Server{ wg: sync.WaitGroup{}, sigChan: make(chan os.Signal), isChild: isChild, SignalHooks: map[int]map[os.Signal][]func(){ - PRE_SIGNAL: map[os.Signal][]func(){ + PreSignal: map[os.Signal][]func(){ syscall.SIGHUP: []func(){}, syscall.SIGINT: []func(){}, syscall.SIGTERM: []func(){}, }, - POST_SIGNAL: map[os.Signal][]func(){ + PostSignal: map[os.Signal][]func(){ syscall.SIGHUP: []func(){}, syscall.SIGINT: []func(){}, syscall.SIGTERM: []func(){}, }, }, - state: STATE_INIT, + state: StateInit, Network: "tcp", } srv.Server = &http.Server{} @@ -137,13 +145,13 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) { return } -// refer http.ListenAndServe +// ListenAndServe refer http.ListenAndServe func ListenAndServe(addr string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServe() } -// refer http.ListenAndServeTLS +// ListenAndServeTLS refer http.ListenAndServeTLS func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { server := NewServer(addr, handler) return server.ListenAndServeTLS(certFile, keyFile) diff --git a/grace/listener.go b/grace/listener.go index 8c5d4f9b..5439d0b2 100644 --- a/grace/listener.go +++ b/grace/listener.go @@ -11,10 +11,10 @@ type graceListener struct { net.Listener stop chan error stopped bool - server *graceServer + server *Server } -func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) { +func newGraceListener(l net.Listener, srv *Server) (el *graceListener) { el = &graceListener{ Listener: l, stop: make(chan error), @@ -46,17 +46,17 @@ func (gl *graceListener) Accept() (c net.Conn, err error) { return } -func (el *graceListener) Close() error { - if el.stopped { +func (gl *graceListener) Close() error { + if gl.stopped { return syscall.EINVAL } - el.stop <- nil - return <-el.stop + gl.stop <- nil + return <-gl.stop } -func (el *graceListener) File() *os.File { +func (gl *graceListener) File() *os.File { // returns a dup(2) - FD_CLOEXEC flag *not* set - tl := el.Listener.(*net.TCPListener) + tl := gl.Listener.(*net.TCPListener) fl, _ := tl.File() return fl } diff --git a/grace/server.go b/grace/server.go index aea8d7d3..f4512ded 100644 --- a/grace/server.go +++ b/grace/server.go @@ -15,7 +15,8 @@ import ( "time" ) -type graceServer struct { +// Server embedded http.Server +type Server struct { *http.Server GraceListener net.Listener SignalHooks map[int]map[os.Signal][]func() @@ -30,19 +31,19 @@ type graceServer struct { // Serve accepts incoming connections on the Listener l, // creating a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. -func (srv *graceServer) Serve() (err error) { - srv.state = STATE_RUNNING +func (srv *Server) Serve() (err error) { + srv.state = StateRunning err = srv.Server.Serve(srv.GraceListener) log.Println(syscall.Getpid(), "Waiting for connections to finish...") srv.wg.Wait() - srv.state = STATE_TERMINATE + srv.state = StateTerminate return } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // to handle requests on incoming connections. If srv.Addr is blank, ":http" is // used. -func (srv *graceServer) ListenAndServe() (err error) { +func (srv *Server) ListenAndServe() (err error) { addr := srv.Addr if addr == "" { addr = ":http" @@ -83,7 +84,7 @@ func (srv *graceServer) ListenAndServe() (err error) { // CA's certificate. // // If srv.Addr is blank, ":https" is used. -func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) { +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { addr := srv.Addr if addr == "" { addr = ":https" @@ -131,9 +132,9 @@ func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) // getListener either opens a new socket to listen on, or takes the acceptor socket // it got passed when restarted. -func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) { +func (srv *Server) getListener(laddr string) (l net.Listener, err error) { if srv.isChild { - var ptrOffset uint = 0 + var ptrOffset uint if len(socketPtrOffsetMap) > 0 { ptrOffset = socketPtrOffsetMap[laddr] log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) @@ -157,7 +158,7 @@ func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) { // handleSignals listens for os Signals and calls any hooked in function that the // user had registered with the signal. -func (srv *graceServer) handleSignals() { +func (srv *Server) handleSignals() { var sig os.Signal signal.Notify( @@ -170,7 +171,7 @@ func (srv *graceServer) handleSignals() { pid := syscall.Getpid() for { sig = <-srv.sigChan - srv.signalHooks(PRE_SIGNAL, sig) + srv.signalHooks(PreSignal, sig) switch sig { case syscall.SIGHUP: log.Println(pid, "Received SIGHUP. forking.") @@ -187,11 +188,11 @@ func (srv *graceServer) handleSignals() { default: log.Printf("Received %v: nothing i care about...\n", sig) } - srv.signalHooks(POST_SIGNAL, sig) + srv.signalHooks(PostSignal, sig) } } -func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) { +func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { return } @@ -204,12 +205,12 @@ func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) { // shutdown closes the listener so that no new connections are accepted. it also // starts a goroutine that will serverTimeout (stop all running requests) the server // after DefaultTimeout. -func (srv *graceServer) shutdown() { - if srv.state != STATE_RUNNING { +func (srv *Server) shutdown() { + if srv.state != StateRunning { return } - srv.state = STATE_SHUTTING_DOWN + srv.state = StateShuttingDown if DefaultTimeout >= 0 { go srv.serverTimeout(DefaultTimeout) } @@ -224,26 +225,26 @@ func (srv *graceServer) shutdown() { // serverTimeout forces the server to shutdown in a given timeout - whether it // finished outstanding requests or not. if Read/WriteTimeout are not set or the // max header size is very big a connection could hang -func (srv *graceServer) serverTimeout(d time.Duration) { +func (srv *Server) serverTimeout(d time.Duration) { defer func() { if r := recover(); r != nil { log.Println("WaitGroup at 0", r) } }() - if srv.state != STATE_SHUTTING_DOWN { + if srv.state != StateShuttingDown { return } time.Sleep(d) log.Println("[STOP - Hammer Time] Forcefully shutting down parent") for { - if srv.state == STATE_TERMINATE { + if srv.state == StateTerminate { break } srv.wg.Done() } } -func (srv *graceServer) fork() (err error) { +func (srv *Server) fork() (err error) { regLock.Lock() defer regLock.Unlock() if runningServersForked { diff --git a/hooks.go b/hooks.go new file mode 100644 index 00000000..78abf8ef --- /dev/null +++ b/hooks.go @@ -0,0 +1,95 @@ +package beego + +import ( + "encoding/json" + "mime" + "net/http" + "path/filepath" + + "github.com/astaxie/beego/session" +) + +// +func registerMime() error { + for k, v := range mimemaps { + mime.AddExtensionType(k, v) + } + return nil +} + +// register default error http handlers, 404,401,403,500 and 503. +func registerDefaultErrorHandler() error { + m := map[string]func(http.ResponseWriter, *http.Request){ + "401": unauthorized, + "402": paymentRequired, + "403": forbidden, + "404": notFound, + "405": methodNotAllowed, + "500": internalServerError, + "501": notImplemented, + "502": badGateway, + "503": serviceUnavailable, + "504": gatewayTimeout, + } + for e, h := range m { + if _, ok := ErrorMaps[e]; !ok { + ErrorHandler(e, h) + } + } + return nil +} + +func registerSession() error { + if BConfig.WebConfig.Session.SessionOn { + var err error + sessionConfig := AppConfig.String("sessionConfig") + if sessionConfig == "" { + conf := map[string]interface{}{ + "cookieName": BConfig.WebConfig.Session.SessionName, + "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime, + "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig), + "secure": BConfig.Listen.EnableHTTPS, + "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie, + "domain": BConfig.WebConfig.Session.SessionDomain, + "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime, + } + confBytes, err := json.Marshal(conf) + if err != nil { + return err + } + sessionConfig = string(confBytes) + } + if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil { + return err + } + go GlobalSessions.GC() + } + return nil +} + +func registerTemplate() error { + if BConfig.WebConfig.AutoRender { + if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { + if BConfig.RunMode == DEV { + Warn(err) + } + return err + } + } + return nil +} + +func registerDocs() error { + if BConfig.WebConfig.EnableDocs { + Get("/docs", serverDocs) + Get("/docs/*", serverDocs) + } + return nil +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil +} diff --git a/httplib/httplib.go b/httplib/httplib.go index 68c22d70..fb64a30a 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package httplib is used as http.Client // Usage: // // import "github.com/astaxie/beego/httplib" @@ -51,7 +52,7 @@ import ( "time" ) -var defaultSetting = BeegoHttpSettings{ +var defaultSetting = BeegoHTTPSettings{ UserAgent: "beegoServer", ConnectTimeout: 60 * time.Second, ReadWriteTimeout: 60 * time.Second, @@ -69,25 +70,19 @@ func createDefaultCookie() { defaultCookieJar, _ = cookiejar.New(nil) } -// Overwrite default settings -func SetDefaultSetting(setting BeegoHttpSettings) { +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { settingMutex.Lock() defer settingMutex.Unlock() defaultSetting = setting - if defaultSetting.ConnectTimeout == 0 { - defaultSetting.ConnectTimeout = 60 * time.Second - } - if defaultSetting.ReadWriteTimeout == 0 { - defaultSetting.ReadWriteTimeout = 60 * time.Second - } } -// return *BeegoHttpRequest with specific method -func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest { +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) if err != nil { - log.Fatal(err) + log.Println("Httplib:", err) } req := http.Request{ URL: u, @@ -97,10 +92,10 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest { ProtoMajor: 1, ProtoMinor: 1, } - return &BeegoHttpRequest{ + return &BeegoHTTPRequest{ url: rawurl, req: &req, - params: map[string]string{}, + params: map[string][]string{}, files: map[string]string{}, setting: defaultSetting, resp: &resp, @@ -108,37 +103,37 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest { } // Get returns *BeegoHttpRequest with GET method. -func Get(url string) *BeegoHttpRequest { +func Get(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "GET") } // Post returns *BeegoHttpRequest with POST method. -func Post(url string) *BeegoHttpRequest { +func Post(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "POST") } // Put returns *BeegoHttpRequest with PUT method. -func Put(url string) *BeegoHttpRequest { +func Put(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "PUT") } // Delete returns *BeegoHttpRequest DELETE method. -func Delete(url string) *BeegoHttpRequest { +func Delete(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "DELETE") } // Head returns *BeegoHttpRequest with HEAD method. -func Head(url string) *BeegoHttpRequest { +func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } -// BeegoHttpSettings -type BeegoHttpSettings struct { +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { ShowDebug bool UserAgent string ConnectTimeout time.Duration ReadWriteTimeout time.Duration - TlsClientConfig *tls.Config + TLSClientConfig *tls.Config Proxy func(*http.Request) (*url.URL, error) Transport http.RoundTripper EnableCookie bool @@ -146,92 +141,92 @@ type BeegoHttpSettings struct { DumpBody bool } -// BeegoHttpRequest provides more useful methods for requesting one url than http.Request. -type BeegoHttpRequest struct { +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { url string req *http.Request - params map[string]string + params map[string][]string files map[string]string - setting BeegoHttpSettings + setting BeegoHTTPSettings resp *http.Response body []byte dump []byte } -// get request -func (b *BeegoHttpRequest) GetRequest() *http.Request { +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { return b.req } -// Change request settings -func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest { +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { b.setting = setting return b } // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. -func (b *BeegoHttpRequest) SetBasicAuth(username, password string) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { b.req.SetBasicAuth(username, password) return b } // SetEnableCookie sets enable/disable cookiejar -func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { b.setting.EnableCookie = enable return b } // SetUserAgent sets User-Agent header field -func (b *BeegoHttpRequest) SetUserAgent(useragent string) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { b.setting.UserAgent = useragent return b } // Debug sets show debug or not when executing request. -func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { b.setting.ShowDebug = isdebug return b } -// Dump Body. -func (b *BeegoHttpRequest) DumpBody(isdump bool) *BeegoHttpRequest { +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { b.setting.DumpBody = isdump return b } -// return the DumpRequest -func (b *BeegoHttpRequest) DumpRequest() []byte { +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { return b.dump } // SetTimeout sets connect time out and read-write time out for BeegoRequest. -func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.setting.ConnectTimeout = connectTimeout b.setting.ReadWriteTimeout = readWriteTimeout return b } // SetTLSClientConfig sets tls connection configurations if visiting https url. -func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest { - b.setting.TlsClientConfig = config +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.setting.TLSClientConfig = config return b } // Header add header item string in request. -func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { b.req.Header.Set(key, value) return b } -// Set HOST -func (b *BeegoHttpRequest) SetHost(host string) *BeegoHttpRequest { +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { b.req.Host = host return b } -// Set the protocol version for incoming requests. +// SetProtocolVersion Set the protocol version for incoming requests. // Client requests always use HTTP/1.1. -func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { if len(vers) == 0 { vers = "HTTP/1.1" } @@ -247,44 +242,49 @@ func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest { } // SetCookie add cookie into request. -func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { b.req.Header.Add("Cookie", cookie.String()) return b } -// Set transport to -func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest { +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { b.setting.Transport = transport return b } -// Set http proxy +// SetProxy set the http proxy // example: // // func(req *http.Request) (*url.URL, error) { // u, _ := url.ParseRequestURI("http://127.0.0.1:8118") // return u, nil // } -func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { b.setting.Proxy = proxy return b } // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... -func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest { - b.params[key] = value +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + if param, ok := b.params[key]; ok { + b.params[key] = append(param, value) + } else { + b.params[key] = []string{value} + } return b } -func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest { +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { b.files[formname] = filename return b } // Body adds request raw body. // it supports string and []byte. -func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest { +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: bf := bytes.NewBufferString(t) @@ -298,8 +298,8 @@ func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest { return b } -// JsonBody adds request raw body encoding by JSON. -func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error) { +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { buf := bytes.NewBuffer(nil) enc := json.NewEncoder(buf) @@ -313,7 +313,7 @@ func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error) return b, nil } -func (b *BeegoHttpRequest) buildUrl(paramBody string) { +func (b *BeegoHTTPRequest) buildURL(paramBody string) { // build GET url with query string if b.req.Method == "GET" && len(paramBody) > 0 { if strings.Index(b.url, "?") != -1 { @@ -334,21 +334,23 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) { for formname, filename := range b.files { fileWriter, err := bodyWriter.CreateFormFile(formname, filename) if err != nil { - log.Fatal(err) + log.Println("Httplib:", err) } fh, err := os.Open(filename) if err != nil { - log.Fatal(err) + log.Println("Httplib:", err) } //iocopy _, err = io.Copy(fileWriter, fh) fh.Close() if err != nil { - log.Fatal(err) + log.Println("Httplib:", err) } } for k, v := range b.params { - bodyWriter.WriteField(k, v) + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } } bodyWriter.Close() pw.Close() @@ -366,11 +368,11 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) { } } -func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { +func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { if b.resp.StatusCode != 0 { return b.resp, nil } - resp, err := b.SendOut() + resp, err := b.DoRequest() if err != nil { return nil, err } @@ -378,21 +380,24 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { return resp, nil } -func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { var paramBody string if len(b.params) > 0 { var buf bytes.Buffer for k, v := range b.params { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(v)) - buf.WriteByte('&') + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } } paramBody = buf.String() paramBody = paramBody[0 : len(paramBody)-1] } - b.buildUrl(paramBody) + b.buildURL(paramBody) url, err := url.Parse(b.url) if err != nil { return nil, err @@ -405,7 +410,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { if trans == nil { // create default transport trans = &http.Transport{ - TLSClientConfig: b.setting.TlsClientConfig, + TLSClientConfig: b.setting.TLSClientConfig, Proxy: b.setting.Proxy, Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), } @@ -413,7 +418,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { // if b.transport is *http.Transport then set the settings. if t, ok := trans.(*http.Transport); ok { if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TlsClientConfig + t.TLSClientConfig = b.setting.TLSClientConfig } if t.Proxy == nil { t.Proxy = b.setting.Proxy @@ -424,7 +429,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { } } - var jar http.CookieJar = nil + var jar http.CookieJar if b.setting.EnableCookie { if defaultCookieJar == nil { createDefaultCookie() @@ -453,7 +458,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { // String returns the body string in response. // it calls Response inner. -func (b *BeegoHttpRequest) String() (string, error) { +func (b *BeegoHTTPRequest) String() (string, error) { data, err := b.Bytes() if err != nil { return "", err @@ -464,7 +469,7 @@ func (b *BeegoHttpRequest) String() (string, error) { // Bytes returns the body []byte in response. // it calls Response inner. -func (b *BeegoHttpRequest) Bytes() ([]byte, error) { +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.body != nil { return b.body, nil } @@ -490,7 +495,7 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) { // ToFile saves the body data in response to one file. // it calls Response inner. -func (b *BeegoHttpRequest) ToFile(filename string) error { +func (b *BeegoHTTPRequest) ToFile(filename string) error { f, err := os.Create(filename) if err != nil { return err @@ -509,9 +514,9 @@ func (b *BeegoHttpRequest) ToFile(filename string) error { return err } -// ToJson returns the map that marshals from the body bytes as json in response . +// ToJSON returns the map that marshals from the body bytes as json in response . // it calls Response inner. -func (b *BeegoHttpRequest) ToJson(v interface{}) error { +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { data, err := b.Bytes() if err != nil { return err @@ -519,9 +524,9 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error { return json.Unmarshal(data, v) } -// ToXml returns the map that marshals from the body bytes as xml in response . +// ToXML returns the map that marshals from the body bytes as xml in response . // it calls Response inner. -func (b *BeegoHttpRequest) ToXml(v interface{}) error { +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { data, err := b.Bytes() if err != nil { return err @@ -530,7 +535,7 @@ func (b *BeegoHttpRequest) ToXml(v interface{}) error { } // Response executes request client gets response mannually. -func (b *BeegoHttpRequest) Response() (*http.Response, error) { +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() } diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 0b551c53..05815054 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -19,6 +19,7 @@ import ( "os" "strings" "testing" + "time" ) func TestResponse(t *testing.T) { @@ -149,10 +150,11 @@ func TestWithUserAgent(t *testing.T) { func TestWithSetting(t *testing.T) { v := "beego" - var setting BeegoHttpSettings + var setting BeegoHTTPSettings setting.EnableCookie = true setting.UserAgent = v setting.Transport = nil + setting.ReadWriteTimeout = 5 * time.Second SetDefaultSetting(setting) str, err := Get("http://httpbin.org/get").String() @@ -176,11 +178,11 @@ func TestToJson(t *testing.T) { t.Log(resp) // httpbin will return http remote addr - type Ip struct { + type IP struct { Origin string `json:"origin"` } - var ip Ip - err = req.ToJson(&ip) + var ip IP + err = req.ToJSON(&ip) if err != nil { t.Fatal(err) } diff --git a/log.go b/log.go index 7949ed96..46ec57dd 100644 --- a/log.go +++ b/log.go @@ -32,20 +32,20 @@ const ( LevelDebug ) -// SetLogLevel sets the global log level used by the simple -// logger. +// BeeLogger references the used application logger. +var BeeLogger = logs.NewLogger(100) + +// SetLevel sets the global log level used by the simple logger. func SetLevel(l int) { BeeLogger.SetLevel(l) } +// SetLogFuncCall set the CallDepth, default is 3 func SetLogFuncCall(b bool) { BeeLogger.EnableFuncCallDepth(b) BeeLogger.SetLogFuncCallDepth(3) } -// logger references the used application logger. -var BeeLogger *logs.BeeLogger - // SetLogger sets a new logger. func SetLogger(adaptername string, config string) error { err := BeeLogger.SetLogger(adaptername, config) @@ -55,10 +55,12 @@ func SetLogger(adaptername string, config string) error { return nil } +// Emergency logs a message at emergency level. func Emergency(v ...interface{}) { BeeLogger.Emergency(generateFmtStr(len(v)), v...) } +// Alert logs a message at alert level. func Alert(v ...interface{}) { BeeLogger.Alert(generateFmtStr(len(v)), v...) } @@ -78,21 +80,22 @@ func Warning(v ...interface{}) { BeeLogger.Warning(generateFmtStr(len(v)), v...) } -// compatibility alias for Warning() +// Warn compatibility alias for Warning() func Warn(v ...interface{}) { BeeLogger.Warn(generateFmtStr(len(v)), v...) } +// Notice logs a message at notice level. func Notice(v ...interface{}) { BeeLogger.Notice(generateFmtStr(len(v)), v...) } -// Info logs a message at info level. +// Informational logs a message at info level. func Informational(v ...interface{}) { BeeLogger.Informational(generateFmtStr(len(v)), v...) } -// compatibility alias for Warning() +// Info compatibility alias for Warning() func Info(v ...interface{}) { BeeLogger.Info(generateFmtStr(len(v)), v...) } diff --git a/logs/conn.go b/logs/conn.go index 2240eece..3655bf51 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -21,9 +21,9 @@ import ( "net" ) -// ConnWriter implements LoggerInterface. +// connWriter implements LoggerInterface. // it writes messages in keep-live tcp connection. -type ConnWriter struct { +type connWriter struct { lg *log.Logger innerWriter io.WriteCloser ReconnectOnMsg bool `json:"reconnectOnMsg"` @@ -33,22 +33,22 @@ type ConnWriter struct { Level int `json:"level"` } -// create new ConnWrite returning as LoggerInterface. -func NewConn() LoggerInterface { - conn := new(ConnWriter) +// NewConn create new ConnWrite returning as LoggerInterface. +func NewConn() Logger { + conn := new(connWriter) conn.Level = LevelTrace return conn } -// init connection writer with json config. +// Init init connection writer with json config. // json config only need key "level". -func (c *ConnWriter) Init(jsonconfig string) error { +func (c *connWriter) Init(jsonconfig string) error { return json.Unmarshal([]byte(jsonconfig), c) } -// write message in connection. +// WriteMsg write message in connection. // if connection is down, try to re-connect. -func (c *ConnWriter) WriteMsg(msg string, level int) error { +func (c *connWriter) WriteMsg(msg string, level int) error { if level > c.Level { return nil } @@ -66,19 +66,19 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error { return nil } -// implementing method. empty. -func (c *ConnWriter) Flush() { +// Flush implementing method. empty. +func (c *connWriter) Flush() { } -// destroy connection writer and close tcp listener. -func (c *ConnWriter) Destroy() { +// Destroy destroy connection writer and close tcp listener. +func (c *connWriter) Destroy() { if c.innerWriter != nil { c.innerWriter.Close() } } -func (c *ConnWriter) connect() error { +func (c *connWriter) connect() error { if c.innerWriter != nil { c.innerWriter.Close() c.innerWriter = nil @@ -98,7 +98,7 @@ func (c *ConnWriter) connect() error { return nil } -func (c *ConnWriter) neddedConnectOnMsg() bool { +func (c *connWriter) neddedConnectOnMsg() bool { if c.Reconnect { c.Reconnect = false return true diff --git a/logs/console.go b/logs/console.go index ce7ecd54..23e8ebca 100644 --- a/logs/console.go +++ b/logs/console.go @@ -21,9 +21,11 @@ import ( "runtime" ) -type Brush func(string) string +// brush is a color join function +type brush func(string) string -func NewBrush(color string) Brush { +// newBrush return a fix color Brush +func newBrush(color string) brush { pre := "\033[" reset := "\033[0m" return func(text string) string { @@ -31,43 +33,43 @@ func NewBrush(color string) Brush { } } -var colors = []Brush{ - NewBrush("1;37"), // Emergency white - NewBrush("1;36"), // Alert cyan - NewBrush("1;35"), // Critical magenta - NewBrush("1;31"), // Error red - NewBrush("1;33"), // Warning yellow - NewBrush("1;32"), // Notice green - NewBrush("1;34"), // Informational blue - NewBrush("1;34"), // Debug blue +var colors = []brush{ + newBrush("1;37"), // Emergency white + newBrush("1;36"), // Alert cyan + newBrush("1;35"), // Critical magenta + newBrush("1;31"), // Error red + newBrush("1;33"), // Warning yellow + newBrush("1;32"), // Notice green + newBrush("1;34"), // Informational blue + newBrush("1;34"), // Debug blue } -// ConsoleWriter implements LoggerInterface and writes messages to terminal. -type ConsoleWriter struct { +// consoleWriter implements LoggerInterface and writes messages to terminal. +type consoleWriter struct { lg *log.Logger Level int `json:"level"` } -// create ConsoleWriter returning as LoggerInterface. -func NewConsole() LoggerInterface { - cw := &ConsoleWriter{ +// NewConsole create ConsoleWriter returning as LoggerInterface. +func NewConsole() Logger { + cw := &consoleWriter{ lg: log.New(os.Stdout, "", log.Ldate|log.Ltime), Level: LevelDebug, } return cw } -// init console logger. +// Init init console logger. // jsonconfig like '{"level":LevelTrace}'. -func (c *ConsoleWriter) Init(jsonconfig string) error { +func (c *consoleWriter) Init(jsonconfig string) error { if len(jsonconfig) == 0 { return nil } return json.Unmarshal([]byte(jsonconfig), c) } -// write message in console. -func (c *ConsoleWriter) WriteMsg(msg string, level int) error { +// WriteMsg write message in console. +func (c *consoleWriter) WriteMsg(msg string, level int) error { if level > c.Level { return nil } @@ -80,13 +82,13 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error { return nil } -// implementing method. empty. -func (c *ConsoleWriter) Destroy() { +// Destroy implementing method. empty. +func (c *consoleWriter) Destroy() { } -// implementing method. empty. -func (c *ConsoleWriter) Flush() { +// Flush implementing method. empty. +func (c *consoleWriter) Flush() { } diff --git a/logs/console_test.go b/logs/console_test.go index 2fad7241..c4bb1da2 100644 --- a/logs/console_test.go +++ b/logs/console_test.go @@ -43,11 +43,3 @@ func TestConsole(t *testing.T) { testConsoleCalls(log2) } -func BenchmarkConsole(b *testing.B) { - log := NewLogger(10000) - log.EnableFuncCallDepth(true) - log.SetLogger("console", "") - for i := 0; i < b.N; i++ { - log.Debug("debug") - } -} diff --git a/logs/es/es.go b/logs/es/es.go index 3a73d4dd..f8dc5f65 100644 --- a/logs/es/es.go +++ b/logs/es/es.go @@ -12,7 +12,8 @@ import ( "github.com/belogik/goes" ) -func NewES() logs.LoggerInterface { +// NewES return a LoggerInterface +func NewES() logs.Logger { cw := &esLogger{ Level: logs.LevelDebug, } @@ -46,6 +47,7 @@ func (el *esLogger) Init(jsonconfig string) error { return nil } +// WriteMsg will write the msg and level into es func (el *esLogger) WriteMsg(msg string, level int) error { if level > el.Level { return nil @@ -63,10 +65,12 @@ func (el *esLogger) WriteMsg(msg string, level int) error { return err } +// Destroy is a empty method func (el *esLogger) Destroy() { } +// Flush is a empty method func (el *esLogger) Flush() { } diff --git a/logs/file.go b/logs/file.go index 2d3449ce..0eae734a 100644 --- a/logs/file.go +++ b/logs/file.go @@ -20,7 +20,6 @@ import ( "errors" "fmt" "io" - "log" "os" "path/filepath" "strings" @@ -28,84 +27,62 @@ import ( "time" ) -// FileLogWriter implements LoggerInterface. +// fileLogWriter implements LoggerInterface. // It writes messages by lines limit, file size limit, or time frequency. -type FileLogWriter struct { - *log.Logger - mw *MuxWriter +type fileLogWriter struct { + sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize // The opened file - Filename string `json:"filename"` + Filename string `json:"filename"` + fileWriter *os.File - Maxlines int `json:"maxlines"` - maxlines_curlines int + // Rotate at line + MaxLines int `json:"maxlines"` + maxLinesCurLines int // Rotate at size - Maxsize int `json:"maxsize"` - maxsize_cursize int + MaxSize int `json:"maxsize"` + maxSizeCurSize int // Rotate daily - Daily bool `json:"daily"` - Maxdays int64 `json:"maxdays"` - daily_opendate int + Daily bool `json:"daily"` + MaxDays int64 `json:"maxdays"` + dailyOpenDate int Rotate bool `json:"rotate"` - startLock sync.Mutex // Only one log can write to the file - Level int `json:"level"` + + Perm os.FileMode `json:"perm"` } -// an *os.File writer with locker. -type MuxWriter struct { - sync.Mutex - fd *os.File -} - -// write to os.File. -func (l *MuxWriter) Write(b []byte) (int, error) { - l.Lock() - defer l.Unlock() - return l.fd.Write(b) -} - -// set os.File in writer. -func (l *MuxWriter) SetFd(fd *os.File) { - if l.fd != nil { - l.fd.Close() - } - l.fd = fd -} - -// create a FileLogWriter returning as LoggerInterface. -func NewFileWriter() LoggerInterface { - w := &FileLogWriter{ +// NewFileWriter create a FileLogWriter returning as LoggerInterface. +func newFileWriter() Logger { + w := &fileLogWriter{ Filename: "", - Maxlines: 1000000, - Maxsize: 1 << 28, //256 MB + MaxLines: 1000000, + MaxSize: 1 << 28, //256 MB Daily: true, - Maxdays: 7, + MaxDays: 7, Rotate: true, Level: LevelTrace, + Perm: 0660, } - // use MuxWriter instead direct use os.File for lock write when rotate - w.mw = new(MuxWriter) - // set MuxWriter as Logger's io.Writer - w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime) return w } // Init file logger with json config. -// jsonconfig like: +// jsonConfig like: // { // "filename":"logs/beego.log", -// "maxlines":10000, +// "maxLines":10000, // "maxsize":1<<30, // "daily":true, -// "maxdays":15, -// "rotate":true +// "maxDays":15, +// "rotate":true, +// "perm":0600 // } -func (w *FileLogWriter) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), w) +func (w *fileLogWriter) Init(jsonConfig string) error { + err := json.Unmarshal([]byte(jsonConfig), w) if err != nil { return err } @@ -117,67 +94,120 @@ func (w *FileLogWriter) Init(jsonconfig string) error { } // start file logger. create log file and set to locker-inside file writer. -func (w *FileLogWriter) startLogger() error { - fd, err := w.createLogFile() +func (w *fileLogWriter) startLogger() error { + file, err := w.createLogFile() if err != nil { return err } - w.mw.SetFd(fd) + if w.fileWriter != nil { + w.fileWriter.Close() + } + w.fileWriter = file return w.initFd() } -func (w *FileLogWriter) docheck(size int) { - w.startLock.Lock() - defer w.startLock.Unlock() - if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) || - (w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) || - (w.Daily && time.Now().Day() != w.daily_opendate)) { - if err := w.DoRotate(); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - return - } - } - w.maxlines_curlines++ - w.maxsize_cursize += size +func (w *fileLogWriter) needRotate(size int, day int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Daily && day != w.dailyOpenDate) + } -// write logger message into file. -func (w *FileLogWriter) WriteMsg(msg string, level int) error { +// WriteMsg write logger message into file. +func (w *fileLogWriter) WriteMsg(msg string, level int) error { if level > w.Level { return nil } - n := 24 + len(msg) // 24 stand for the length "2013/06/23 21:00:22 [T] " - w.docheck(n) - w.Logger.Println(msg) - return nil + //2016/01/12 21:34:33 + now := time.Now() + y, mo, d := now.Date() + h, mi, s := now.Clock() + //len(2006/01/02 15:03:04)==19 + var buf [20]byte + t := 3 + for y >= 10 { + p := y / 10 + buf[t] = byte('0' + y - p*10) + y = p + t-- + } + buf[0] = byte('0' + y) + buf[4] = '/' + if mo > 9 { + buf[5] = '1' + buf[6] = byte('0' + mo - 9) + } else { + buf[5] = '0' + buf[6] = byte('0' + mo) + } + buf[7] = '/' + t = d / 10 + buf[8] = byte('0' + t) + buf[9] = byte('0' + d - t*10) + buf[10] = ' ' + t = h / 10 + buf[11] = byte('0' + t) + buf[12] = byte('0' + h - t*10) + buf[13] = ':' + t = mi / 10 + buf[14] = byte('0' + t) + buf[15] = byte('0' + mi - t*10) + buf[16] = ':' + t = s / 10 + buf[17] = byte('0' + t) + buf[18] = byte('0' + s - t*10) + buf[19] = ' ' + msg = string(buf[0:]) + msg + "\n" + + if w.Rotate { + if w.needRotate(len(msg), d) { + w.Lock() + if w.needRotate(len(msg), d) { + if err := w.doRotate(); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + + } + w.Unlock() + } + } + + w.Lock() + _, err := w.fileWriter.Write([]byte(msg)) + if err == nil { + w.maxLinesCurLines++ + w.maxSizeCurSize += len(msg) + } + w.Unlock() + return err } -func (w *FileLogWriter) createLogFile() (*os.File, error) { +func (w *fileLogWriter) createLogFile() (*os.File, error) { // Open the log file - fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660) + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm) return fd, err } -func (w *FileLogWriter) initFd() error { - fd := w.mw.fd - finfo, err := fd.Stat() +func (w *fileLogWriter) initFd() error { + fd := w.fileWriter + fInfo, err := fd.Stat() if err != nil { return fmt.Errorf("get stat err: %s\n", err) } - w.maxsize_cursize = int(finfo.Size()) - w.daily_opendate = time.Now().Day() - w.maxlines_curlines = 0 - if finfo.Size() > 0 { + w.maxSizeCurSize = int(fInfo.Size()) + w.dailyOpenDate = time.Now().Day() + w.maxLinesCurLines = 0 + if fInfo.Size() > 0 { count, err := w.lines() if err != nil { return err } - w.maxlines_curlines = count + w.maxLinesCurLines = count } return nil } -func (w *FileLogWriter) lines() (int, error) { +func (w *fileLogWriter) lines() (int, error) { fd, err := os.Open(w.Filename) if err != nil { return 0, err @@ -205,59 +235,60 @@ func (w *FileLogWriter) lines() (int, error) { } // DoRotate means it need to write file in new file. -// new file name like xx.log.2013-01-01.2 -func (w *FileLogWriter) DoRotate() error { +// new file name like xx.2013-01-01.2.log +func (w *fileLogWriter) doRotate() error { _, err := os.Lstat(w.Filename) - if err == nil { // file exists - // Find the next available number - num := 1 - fname := "" - for ; err == nil && num <= 999; num++ { - fname = w.Filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num) - _, err = os.Lstat(fname) - } - // return error if the last file checked still existed - if err == nil { - return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename) - } - - // block Logger's io.Writer - w.mw.Lock() - defer w.mw.Unlock() - - fd := w.mw.fd - fd.Close() - - // close fd before rename - // Rename the file to its newfound home - err = os.Rename(w.Filename, fname) - if err != nil { - return fmt.Errorf("Rotate: %s\n", err) - } - - // re-start logger - err = w.startLogger() - if err != nil { - return fmt.Errorf("Rotate StartLogger: %s\n", err) - } - - go w.deleteOldLog() + if err != nil { + return err + } + // file exists + // Find the next available number + num := 1 + fName := "" + suffix := filepath.Ext(w.Filename) + filenameOnly := strings.TrimSuffix(w.Filename, suffix) + if suffix == "" { + suffix = ".log" + } + for ; err == nil && num <= 999; num++ { + fName = filenameOnly + fmt.Sprintf(".%s.%03d%s", time.Now().Format("2006-01-02"), num, suffix) + _, err = os.Lstat(fName) + } + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename) } + // close fileWriter before rename + w.fileWriter.Close() + + // Rename the file to its new found name + // even if occurs error,we MUST guarantee to restart new logger + renameErr := os.Rename(w.Filename, fName) + // re-start logger + startLoggerErr := w.startLogger() + go w.deleteOldLog() + + if startLoggerErr != nil { + return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) + } + if renameErr != nil { + return fmt.Errorf("Rotate: %s\n", renameErr) + } return nil + } -func (w *FileLogWriter) deleteOldLog() { +func (w *fileLogWriter) deleteOldLog() { dir := filepath.Dir(w.Filename) filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { defer func() { if r := recover(); r != nil { - returnErr = fmt.Errorf("Unable to delete old log '%s', error: %+v", path, r) - fmt.Println(returnErr) + fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) } }() - if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.Maxdays) { + if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) { if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) { os.Remove(path) } @@ -266,18 +297,18 @@ func (w *FileLogWriter) deleteOldLog() { }) } -// destroy file logger, close file writer. -func (w *FileLogWriter) Destroy() { - w.mw.fd.Close() +// Destroy close the file description, close file writer. +func (w *fileLogWriter) Destroy() { + w.fileWriter.Close() } -// flush file logger. +// Flush flush file logger. // there are no buffering messages in file logger in memory. // flush file means sync file from disk. -func (w *FileLogWriter) Flush() { - w.mw.fd.Sync() +func (w *fileLogWriter) Flush() { + w.fileWriter.Sync() } func init() { - Register("file", NewFileWriter) + Register("file", newFileWriter) } diff --git a/logs/file_test.go b/logs/file_test.go index c71e9bb4..f9b54c26 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -23,7 +23,7 @@ import ( "time" ) -func TestFile(t *testing.T) { +func TestFile1(t *testing.T) { log := NewLogger(10000) log.SetLogger("file", `{"filename":"test.log"}`) log.Debug("debug") @@ -34,25 +34,24 @@ func TestFile(t *testing.T) { log.Alert("alert") log.Critical("critical") log.Emergency("emergency") - time.Sleep(time.Second * 4) f, err := os.Open("test.log") if err != nil { t.Fatal(err) } b := bufio.NewReader(f) - linenum := 0 + lineNum := 0 for { line, _, err := b.ReadLine() if err != nil { break } if len(line) > 0 { - linenum++ + lineNum++ } } var expected = LevelDebug + 1 - if linenum != expected { - t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines") + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } os.Remove("test.log") } @@ -68,25 +67,24 @@ func TestFile2(t *testing.T) { log.Alert("alert") log.Critical("critical") log.Emergency("emergency") - time.Sleep(time.Second * 4) f, err := os.Open("test2.log") if err != nil { t.Fatal(err) } b := bufio.NewReader(f) - linenum := 0 + lineNum := 0 for { line, _, err := b.ReadLine() if err != nil { break } if len(line) > 0 { - linenum++ + lineNum++ } } var expected = LevelError + 1 - if linenum != expected { - t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines") + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } os.Remove("test2.log") } @@ -102,13 +100,13 @@ func TestFileRotate(t *testing.T) { log.Alert("alert") log.Critical("critical") log.Emergency("emergency") - time.Sleep(time.Second * 4) - rotatename := "test3.log" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) - b, err := exists(rotatename) + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + b, err := exists(rotateName) if !b || err != nil { + os.Remove("test3.log") t.Fatal("rotate not generated") } - os.Remove(rotatename) + os.Remove(rotateName) os.Remove("test3.log") } @@ -131,3 +129,46 @@ func BenchmarkFile(b *testing.B) { } os.Remove("test4.log") } + + +func BenchmarkFileAsynchronous(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronousCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileOnGoroutine(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + go log.Debug("debug") + } + os.Remove("test4.log") +} diff --git a/logs/log.go b/logs/log.go index cebbc737..ccaaa3ad 100644 --- a/logs/log.go +++ b/logs/log.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package logs provide a general log interface // Usage: // // import "github.com/astaxie/beego/logs" @@ -34,8 +35,10 @@ package logs import ( "fmt" + "os" "path" "runtime" + "strconv" "sync" ) @@ -60,10 +63,10 @@ const ( LevelWarn = LevelWarning ) -type loggerType func() LoggerInterface +type loggerType func() Logger -// LoggerInterface defines the behavior of a log provider. -type LoggerInterface interface { +// Logger defines the behavior of a log provider. +type Logger interface { Init(config string) error WriteMsg(msg string, level int) error Destroy() @@ -93,8 +96,13 @@ type BeeLogger struct { enableFuncCallDepth bool loggerFuncCallDepth int asynchronous bool - msg chan *logMsg - outputs map[string]LoggerInterface + msgChan chan *logMsg + outputs []*nameLogger +} + +type nameLogger struct { + Logger + name string } type logMsg struct { @@ -102,59 +110,79 @@ type logMsg struct { msg string } +var logMsgPool *sync.Pool + // NewLogger returns a new BeeLogger. -// channellen means the number of messages in chan. +// channelLen means the number of messages in chan(used where asynchronous is true). // if the buffering chan is full, logger adapters write to file or other way. -func NewLogger(channellen int64) *BeeLogger { +func NewLogger(channelLen int64) *BeeLogger { bl := new(BeeLogger) bl.level = LevelDebug bl.loggerFuncCallDepth = 2 - bl.msg = make(chan *logMsg, channellen) - bl.outputs = make(map[string]LoggerInterface) + bl.msgChan = make(chan *logMsg, channelLen) return bl } +// Async set the log to asynchronous and start the goroutine func (bl *BeeLogger) Async() *BeeLogger { bl.asynchronous = true + logMsgPool = &sync.Pool{ + New: func() interface{} { + return &logMsg{} + }, + } go bl.startLogger() return bl } // SetLogger provides a given logger adapter into BeeLogger with config string. // config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) SetLogger(adaptername string, config string) error { +func (bl *BeeLogger) SetLogger(adapterName string, config string) error { bl.lock.Lock() defer bl.lock.Unlock() - if log, ok := adapters[adaptername]; ok { + if log, ok := adapters[adapterName]; ok { lg := log() err := lg.Init(config) - bl.outputs[adaptername] = lg if err != nil { - fmt.Println("logs.BeeLogger.SetLogger: " + err.Error()) + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) } else { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername) + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } return nil } -// remove a logger adapter in BeeLogger. -func (bl *BeeLogger) DelLogger(adaptername string) error { +// DelLogger remove a logger adapter in BeeLogger. +func (bl *BeeLogger) DelLogger(adapterName string) error { bl.lock.Lock() defer bl.lock.Unlock() - if lg, ok := bl.outputs[adaptername]; ok { - lg.Destroy() - delete(bl.outputs, adaptername) - return nil - } else { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername) + outputs := []*nameLogger{} + for _, lg := range bl.outputs { + if lg.name == adapterName { + lg.Destroy() + } else { + outputs = append(outputs, lg) + } + } + if len(outputs) == len(bl.outputs) { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + bl.outputs = outputs + return nil +} + +func (bl *BeeLogger) writeToLoggers(msg string, level int) { + for _, l := range bl.outputs { + err := l.WriteMsg(msg, level) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) + } } } -func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { - lm := new(logMsg) - lm.level = loglevel +func (bl *BeeLogger) writeMsg(logLevel int, msg string) error { if bl.enableFuncCallDepth { _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) if !ok { @@ -162,43 +190,37 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { line = 0 } _, filename := path.Split(file) - lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg) - } else { - lm.msg = msg + msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg } if bl.asynchronous { - bl.msg <- lm + lm := logMsgPool.Get().(*logMsg) + lm.level = logLevel + lm.msg = msg + bl.msgChan <- lm } else { - for name, l := range bl.outputs { - err := l.WriteMsg(lm.msg, lm.level) - if err != nil { - fmt.Println("unable to WriteMsg to adapter:", name, err) - return err - } - } + bl.writeToLoggers(msg, logLevel) } return nil } -// Set log message level. -// +// SetLevel Set log message level. // If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), // log providers will not even be sent the message. func (bl *BeeLogger) SetLevel(l int) { bl.level = l } -// set log funcCallDepth +// SetLogFuncCallDepth set log funcCallDepth func (bl *BeeLogger) SetLogFuncCallDepth(d int) { bl.loggerFuncCallDepth = d } -// get log funcCallDepth for wrapper +// GetLogFuncCallDepth return log funcCallDepth for wrapper func (bl *BeeLogger) GetLogFuncCallDepth() int { return bl.loggerFuncCallDepth } -// enable log funcCallDepth +// EnableFuncCallDepth enable log funcCallDepth func (bl *BeeLogger) EnableFuncCallDepth(b bool) { bl.enableFuncCallDepth = b } @@ -208,137 +230,129 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) { func (bl *BeeLogger) startLogger() { for { select { - case bm := <-bl.msg: - for _, l := range bl.outputs { - err := l.WriteMsg(bm.msg, bm.level) - if err != nil { - fmt.Println("ERROR, unable to WriteMsg:", err) - } - } + case bm := <-bl.msgChan: + bl.writeToLoggers(bm.msg, bm.level) + logMsgPool.Put(bm) } } } -// Log EMERGENCY level message. +// Emergency Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { return } msg := fmt.Sprintf("[M] "+format, v...) - bl.writerMsg(LevelEmergency, msg) + bl.writeMsg(LevelEmergency, msg) } -// Log ALERT level message. +// Alert Log ALERT level message. func (bl *BeeLogger) Alert(format string, v ...interface{}) { if LevelAlert > bl.level { return } msg := fmt.Sprintf("[A] "+format, v...) - bl.writerMsg(LevelAlert, msg) + bl.writeMsg(LevelAlert, msg) } -// Log CRITICAL level message. +// Critical Log CRITICAL level message. func (bl *BeeLogger) Critical(format string, v ...interface{}) { if LevelCritical > bl.level { return } msg := fmt.Sprintf("[C] "+format, v...) - bl.writerMsg(LevelCritical, msg) + bl.writeMsg(LevelCritical, msg) } -// Log ERROR level message. +// Error Log ERROR level message. func (bl *BeeLogger) Error(format string, v ...interface{}) { if LevelError > bl.level { return } msg := fmt.Sprintf("[E] "+format, v...) - bl.writerMsg(LevelError, msg) + bl.writeMsg(LevelError, msg) } -// Log WARNING level message. +// Warning Log WARNING level message. func (bl *BeeLogger) Warning(format string, v ...interface{}) { if LevelWarning > bl.level { return } msg := fmt.Sprintf("[W] "+format, v...) - bl.writerMsg(LevelWarning, msg) + bl.writeMsg(LevelWarning, msg) } -// Log NOTICE level message. +// Notice Log NOTICE level message. func (bl *BeeLogger) Notice(format string, v ...interface{}) { if LevelNotice > bl.level { return } msg := fmt.Sprintf("[N] "+format, v...) - bl.writerMsg(LevelNotice, msg) + bl.writeMsg(LevelNotice, msg) } -// Log INFORMATIONAL level message. +// Informational Log INFORMATIONAL level message. func (bl *BeeLogger) Informational(format string, v ...interface{}) { if LevelInformational > bl.level { return } msg := fmt.Sprintf("[I] "+format, v...) - bl.writerMsg(LevelInformational, msg) + bl.writeMsg(LevelInformational, msg) } -// Log DEBUG level message. +// Debug Log DEBUG level message. func (bl *BeeLogger) Debug(format string, v ...interface{}) { if LevelDebug > bl.level { return } msg := fmt.Sprintf("[D] "+format, v...) - bl.writerMsg(LevelDebug, msg) + bl.writeMsg(LevelDebug, msg) } -// Log WARN level message. +// Warn Log WARN level message. // compatibility alias for Warning() func (bl *BeeLogger) Warn(format string, v ...interface{}) { if LevelWarning > bl.level { return } msg := fmt.Sprintf("[W] "+format, v...) - bl.writerMsg(LevelWarning, msg) + bl.writeMsg(LevelWarning, msg) } -// Log INFO level message. +// Info Log INFO level message. // compatibility alias for Informational() func (bl *BeeLogger) Info(format string, v ...interface{}) { if LevelInformational > bl.level { return } msg := fmt.Sprintf("[I] "+format, v...) - bl.writerMsg(LevelInformational, msg) + bl.writeMsg(LevelInformational, msg) } -// Log TRACE level message. +// Trace Log TRACE level message. // compatibility alias for Debug() func (bl *BeeLogger) Trace(format string, v ...interface{}) { if LevelDebug > bl.level { return } msg := fmt.Sprintf("[D] "+format, v...) - bl.writerMsg(LevelDebug, msg) + bl.writeMsg(LevelDebug, msg) } -// flush all chan data. +// Flush flush all chan data. func (bl *BeeLogger) Flush() { for _, l := range bl.outputs { l.Flush() } } -// close logger, flush all chan data and destroy all adapters in BeeLogger. +// Close close logger, flush all chan data and destroy all adapters in BeeLogger. func (bl *BeeLogger) Close() { for { - if len(bl.msg) > 0 { - bm := <-bl.msg - for _, l := range bl.outputs { - err := l.WriteMsg(bm.msg, bm.level) - if err != nil { - fmt.Println("ERROR, unable to WriteMsg (while closing logger):", err) - } - } + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.msg, bm.level) + logMsgPool.Put(bm) continue } break diff --git a/logs/smtp.go b/logs/smtp.go index 95123ebf..748462f9 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -24,31 +24,26 @@ import ( "time" ) -const ( -// no usage -// subjectPhrase = "Diagnostic message from server" -) - -// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server. -type SmtpWriter struct { - Username string `json:"Username"` +// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. +type SMTPWriter struct { + Username string `json:"username"` Password string `json:"password"` - Host string `json:"Host"` + Host string `json:"host"` Subject string `json:"subject"` FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` } -// create smtp writer. -func NewSmtpWriter() LoggerInterface { - return &SmtpWriter{Level: LevelTrace} +// NewSMTPWriter create smtp writer. +func newSMTPWriter() Logger { + return &SMTPWriter{Level: LevelTrace} } -// init smtp writer with json config. +// Init smtp writer with json config. // config like: // { -// "Username":"example@gmail.com", +// "username":"example@gmail.com", // "password:"password", // "host":"smtp.gmail.com:465", // "subject":"email title", @@ -56,7 +51,7 @@ func NewSmtpWriter() LoggerInterface { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SmtpWriter) Init(jsonconfig string) error { +func (s *SMTPWriter) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), s) if err != nil { return err @@ -64,7 +59,7 @@ func (s *SmtpWriter) Init(jsonconfig string) error { return nil } -func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth { +func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { return nil } @@ -76,7 +71,7 @@ func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth { ) } -func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { +func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { client, err := smtp.Dial(hostAddressWithPort) if err != nil { return err @@ -129,9 +124,9 @@ func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return nil } -// write message in smtp writer. +// WriteMsg write message in smtp writer. // it will send an email with subject and only this message. -func (s *SmtpWriter) WriteMsg(msg string, level int) error { +func (s *SMTPWriter) WriteMsg(msg string, level int) error { if level > s.Level { return nil } @@ -139,27 +134,27 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error { hp := strings.Split(s.Host, ":") // Set up authentication information. - auth := s.GetSmtpAuth(hp[0]) + auth := s.getSMTPAuth(hp[0]) // Connect to the server, authenticate, set the sender and recipient, // and send the email all in one step. - content_type := "Content-Type: text/plain" + "; charset=UTF-8" + contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + content_type + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } -// implementing method. empty. -func (s *SmtpWriter) Flush() { +// Flush implementing method. empty. +func (s *SMTPWriter) Flush() { return } -// implementing method. empty. -func (s *SmtpWriter) Destroy() { +// Destroy implementing method. empty. +func (s *SMTPWriter) Destroy() { return } func init() { - Register("smtp", NewSmtpWriter) + Register("smtp", newSMTPWriter) } diff --git a/memzipfile.go b/memzipfile.go deleted file mode 100644 index cc5e3851..00000000 --- a/memzipfile.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "errors" - "io" - "io/ioutil" - "net/http" - "os" - "strings" - "sync" - "time" -) - -var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo) -var lock sync.RWMutex - -// OpenMemZipFile returns MemFile object with a compressed static file. -// it's used for serve static file if gzip enable. -func openMemZipFile(path string, zip string) (*memFile, error) { - osfile, e := os.Open(path) - if e != nil { - return nil, e - } - defer osfile.Close() - - osfileinfo, e := osfile.Stat() - if e != nil { - return nil, e - } - - modtime := osfileinfo.ModTime() - fileSize := osfileinfo.Size() - lock.RLock() - cfi, ok := gmfim[zip+":"+path] - lock.RUnlock() - if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) { - var content []byte - if zip == "gzip" { - var zipbuf bytes.Buffer - gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression) - if e != nil { - return nil, e - } - _, e = io.Copy(gzipwriter, osfile) - gzipwriter.Close() - if e != nil { - return nil, e - } - content, e = ioutil.ReadAll(&zipbuf) - if e != nil { - return nil, e - } - } else if zip == "deflate" { - var zipbuf bytes.Buffer - deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression) - if e != nil { - return nil, e - } - _, e = io.Copy(deflatewriter, osfile) - deflatewriter.Close() - if e != nil { - return nil, e - } - content, e = ioutil.ReadAll(&zipbuf) - if e != nil { - return nil, e - } - } else { - content, e = ioutil.ReadAll(osfile) - if e != nil { - return nil, e - } - } - - cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize} - lock.Lock() - defer lock.Unlock() - gmfim[zip+":"+path] = cfi - } - return &memFile{fi: cfi, offset: 0}, nil -} - -// MemFileInfo contains a compressed file bytes and file information. -// it implements os.FileInfo interface. -type memFileInfo struct { - os.FileInfo - modTime time.Time - content []byte - contentSize int64 - fileSize int64 -} - -// Name returns the compressed filename. -func (fi *memFileInfo) Name() string { - return fi.Name() -} - -// Size returns the raw file content size, not compressed size. -func (fi *memFileInfo) Size() int64 { - return fi.contentSize -} - -// Mode returns file mode. -func (fi *memFileInfo) Mode() os.FileMode { - return fi.Mode() -} - -// ModTime returns the last modified time of raw file. -func (fi *memFileInfo) ModTime() time.Time { - return fi.modTime -} - -// IsDir returns the compressing file is a directory or not. -func (fi *memFileInfo) IsDir() bool { - return fi.IsDir() -} - -// return nil. implement the os.FileInfo interface method. -func (fi *memFileInfo) Sys() interface{} { - return nil -} - -// MemFile contains MemFileInfo and bytes offset when reading. -// it implements io.Reader,io.ReadCloser and io.Seeker. -type memFile struct { - fi *memFileInfo - offset int64 -} - -// Close memfile. -func (f *memFile) Close() error { - return nil -} - -// Get os.FileInfo of memfile. -func (f *memFile) Stat() (os.FileInfo, error) { - return f.fi, nil -} - -// read os.FileInfo of files in directory of memfile. -// it returns empty slice. -func (f *memFile) Readdir(count int) ([]os.FileInfo, error) { - infos := []os.FileInfo{} - - return infos, nil -} - -// Read bytes from the compressed file bytes. -func (f *memFile) Read(p []byte) (n int, err error) { - if len(f.fi.content)-int(f.offset) >= len(p) { - n = len(p) - } else { - n = len(f.fi.content) - int(f.offset) - err = io.EOF - } - copy(p, f.fi.content[f.offset:f.offset+int64(n)]) - f.offset += int64(n) - return -} - -var errWhence = errors.New("Seek: invalid whence") -var errOffset = errors.New("Seek: invalid offset") - -// Read bytes from the compressed file bytes by seeker. -func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) { - switch whence { - default: - return 0, errWhence - case os.SEEK_SET: - case os.SEEK_CUR: - offset += f.offset - case os.SEEK_END: - offset += int64(len(f.fi.content)) - } - if offset < 0 || int(offset) > len(f.fi.content) { - return 0, errOffset - } - f.offset = offset - return f.offset, nil -} - -// GetAcceptEncodingZip returns accept encoding format in http header. -// zip is first, then deflate if both accepted. -// If no accepted, return empty string. -func getAcceptEncodingZip(r *http.Request) string { - ss := r.Header.Get("Accept-Encoding") - ss = strings.ToLower(ss) - if strings.Contains(ss, "gzip") { - return "gzip" - } else if strings.Contains(ss, "deflate") { - return "deflate" - } else { - return "" - } -} diff --git a/middleware/i18n.go b/middleware/i18n.go deleted file mode 100644 index f54b4bb5..00000000 --- a/middleware/i18n.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Usage: -// -// import "github.com/astaxie/beego/middleware" -// -// I18N = middleware.NewLocale("conf/i18n.conf", beego.AppConfig.String("language")) -// -// more docs: http://beego.me/docs/module/i18n.md -package middleware - -import ( - "encoding/json" - "io/ioutil" - "os" -) - -type Translation struct { - filepath string - CurrentLocal string - Locales map[string]map[string]string -} - -func NewLocale(filepath string, defaultlocal string) *Translation { - file, err := os.Open(filepath) - if err != nil { - panic("open " + filepath + " err :" + err.Error()) - } - data, err := ioutil.ReadAll(file) - if err != nil { - panic("read " + filepath + " err :" + err.Error()) - } - - i18n := make(map[string]map[string]string) - if err = json.Unmarshal(data, &i18n); err != nil { - panic("json.Unmarshal " + filepath + " err :" + err.Error()) - } - return &Translation{ - filepath: filepath, - CurrentLocal: defaultlocal, - Locales: i18n, - } -} - -func (t *Translation) SetLocale(local string) { - t.CurrentLocal = local -} - -func (t *Translation) Translate(key string, local string) string { - if local == "" { - local = t.CurrentLocal - } - if ct, ok := t.Locales[key]; ok { - if v, o := ct[local]; o { - return v - } - } - return key -} diff --git a/migration/ddl.go b/migration/ddl.go index f9b60117..51243337 100644 --- a/migration/ddl.go +++ b/migration/ddl.go @@ -14,33 +14,40 @@ package migration +// Table store the tablename and Column type Table struct { TableName string Columns []*Column } +// Create return the create sql func (t *Table) Create() string { return "" } +// Drop return the drop sql func (t *Table) Drop() string { return "" } +// Column define the columns name type and Default type Column struct { Name string Type string Default interface{} } +// Create return create sql with the provided tbname and columns func Create(tbname string, columns ...Column) string { return "" } +// Drop return the drop sql with the provided tbname and columns func Drop(tbname string, columns ...Column) string { return "" } +// TableDDL is still in think func TableDDL(tbname string, columns ...Column) string { return "" } diff --git a/migration/migration.go b/migration/migration.go index d64d60d3..1591bc50 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// migration package for migration +// Package migration is used for migration // // The table structure is as follow: // @@ -39,8 +39,8 @@ import ( // const the data format for the bee generate migration datatype const ( - M_DATE_FORMAT = "20060102_150405" - M_DB_DATE_FORMAT = "2006-01-02 15:04:05" + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" ) // Migrationer is an interface for all Migration struct @@ -60,24 +60,24 @@ func init() { migrationMap = make(map[string]Migrationer) } -// the basic type which will implement the basic type +// Migration the basic type which will implement the basic type type Migration struct { sqls []string Created string } -// implement in the Inheritance struct for upgrade +// Up implement in the Inheritance struct for upgrade func (m *Migration) Up() { } -// implement in the Inheritance struct for down +// Down implement in the Inheritance struct for down func (m *Migration) Down() { } -// add sql want to execute -func (m *Migration) Sql(sql string) { +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { m.sqls = append(m.sqls, sql) } @@ -86,7 +86,7 @@ func (m *Migration) Reset() { m.sqls = make([]string, 0) } -// execute the sql already add in the sql +// Exec execute the sql already add in the sql func (m *Migration) Exec(name, status string) error { o := orm.NewOrm() for _, s := range m.sqls { @@ -104,33 +104,32 @@ func (m *Migration) addOrUpdateRecord(name, status string) error { o := orm.NewOrm() if status == "down" { status = "rollback" - p, err := o.Raw("update migrations set `status` = ?, `rollback_statements` = ?, `created_at` = ? where name = ?").Prepare() + p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() if err != nil { return nil } - _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(M_DB_DATE_FORMAT), name) - return err - } else { - status = "update" - p, err := o.Raw("insert into migrations(`name`, `created_at`, `statements`, `status`) values(?,?,?,?)").Prepare() - if err != nil { - return err - } - _, err = p.Exec(name, time.Now().Format(M_DB_DATE_FORMAT), strings.Join(m.sqls, "; "), status) + _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) return err } + status = "update" + p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() + if err != nil { + return err + } + _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) + return err } -// get the unixtime from the Created +// GetCreated get the unixtime from the Created func (m *Migration) GetCreated() int64 { - t, err := time.Parse(M_DATE_FORMAT, m.Created) + t, err := time.Parse(DateFormat, m.Created) if err != nil { return 0 } return t.Unix() } -// register the Migration in the map +// Register register the Migration in the map func Register(name string, m Migrationer) error { if _, ok := migrationMap[name]; ok { return errors.New("already exist name:" + name) @@ -139,7 +138,7 @@ func Register(name string, m Migrationer) error { return nil } -// upgrate the migration from lasttime +// Upgrade upgrate the migration from lasttime func Upgrade(lasttime int64) error { sm := sortMap(migrationMap) i := 0 @@ -163,7 +162,7 @@ func Upgrade(lasttime int64) error { return nil } -//rollback the migration by the name +// Rollback rollback the migration by the name func Rollback(name string) error { if v, ok := migrationMap[name]; ok { beego.Info("start rollback") @@ -178,14 +177,13 @@ func Rollback(name string) error { beego.Info("end rollback") time.Sleep(2 * time.Second) return nil - } else { - beego.Error("not exist the migrationMap name:" + name) - time.Sleep(2 * time.Second) - return errors.New("not exist the migrationMap name:" + name) } + beego.Error("not exist the migrationMap name:" + name) + time.Sleep(2 * time.Second) + return errors.New("not exist the migrationMap name:" + name) } -// reset all migration +// Reset reset all migration // run all migration's down function func Reset() error { sm := sortMap(migrationMap) @@ -214,7 +212,7 @@ func Reset() error { return nil } -// first Reset, then Upgrade +// Refresh first Reset, then Upgrade func Refresh() error { err := Reset() if err != nil { diff --git a/mime.go b/mime.go index 20246c21..e85fcb2a 100644 --- a/mime.go +++ b/mime.go @@ -14,11 +14,7 @@ package beego -import ( - "mime" -) - -var mimemaps map[string]string = map[string]string{ +var mimemaps = map[string]string{ ".3dm": "x-world/x-3dmf", ".3dmf": "x-world/x-3dmf", ".7z": "application/x-7z-compressed", @@ -558,10 +554,3 @@ var mimemaps map[string]string = map[string]string{ ".oex": "application/x-opera-extension", ".mustache": "text/html", } - -func initMime() error { - for k, v := range mimemaps { - mime.AddExtensionType(k, v) - } - return nil -} diff --git a/namespace.go b/namespace.go index ebb7c14f..0dfdd7af 100644 --- a/namespace.go +++ b/namespace.go @@ -23,16 +23,17 @@ import ( type namespaceCond func(*beecontext.Context) bool -type innnerNamespace func(*Namespace) +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) // Namespace is store all the info type Namespace struct { prefix string - handlers *ControllerRegistor + handlers *ControllerRegister } -// get new Namespace -func NewNamespace(prefix string, params ...innnerNamespace) *Namespace { +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { ns := &Namespace{ prefix: prefix, handlers: NewControllerRegister(), @@ -43,7 +44,7 @@ func NewNamespace(prefix string, params ...innnerNamespace) *Namespace { return ns } -// set condtion function +// Cond set condtion function // if cond return true can run this namespace, else can't // usage: // ns.Cond(func (ctx *context.Context) bool{ @@ -72,7 +73,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace { return n } -// add filter in the Namespace +// Filter add filter in the Namespace // action has before & after // FilterFunc // usage: @@ -95,98 +96,98 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { return n } -// same as beego.Rourer +// Router same as beego.Rourer // refer: https://godoc.org/github.com/astaxie/beego#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { n.handlers.Add(rootpath, c, mappingMethods...) return n } -// same as beego.AutoRouter +// AutoRouter same as beego.AutoRouter // refer: https://godoc.org/github.com/astaxie/beego#AutoRouter func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { n.handlers.AddAuto(c) return n } -// same as beego.AutoPrefix +// AutoPrefix same as beego.AutoPrefix // refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { n.handlers.AddAutoPrefix(prefix, c) return n } -// same as beego.Get +// Get same as beego.Get // refer: https://godoc.org/github.com/astaxie/beego#Get func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { n.handlers.Get(rootpath, f) return n } -// same as beego.Post +// Post same as beego.Post // refer: https://godoc.org/github.com/astaxie/beego#Post func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { n.handlers.Post(rootpath, f) return n } -// same as beego.Delete +// Delete same as beego.Delete // refer: https://godoc.org/github.com/astaxie/beego#Delete func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { n.handlers.Delete(rootpath, f) return n } -// same as beego.Put +// Put same as beego.Put // refer: https://godoc.org/github.com/astaxie/beego#Put func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { n.handlers.Put(rootpath, f) return n } -// same as beego.Head +// Head same as beego.Head // refer: https://godoc.org/github.com/astaxie/beego#Head func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { n.handlers.Head(rootpath, f) return n } -// same as beego.Options +// Options same as beego.Options // refer: https://godoc.org/github.com/astaxie/beego#Options func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { n.handlers.Options(rootpath, f) return n } -// same as beego.Patch +// Patch same as beego.Patch // refer: https://godoc.org/github.com/astaxie/beego#Patch func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { n.handlers.Patch(rootpath, f) return n } -// same as beego.Any +// Any same as beego.Any // refer: https://godoc.org/github.com/astaxie/beego#Any func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { n.handlers.Any(rootpath, f) return n } -// same as beego.Handler +// Handler same as beego.Handler // refer: https://godoc.org/github.com/astaxie/beego#Handler func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { n.handlers.Handler(rootpath, h) return n } -// add include class +// Include add include class // refer: https://godoc.org/github.com/astaxie/beego#Include func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { n.handlers.Include(cList...) return n } -// nest Namespace +// Namespace add nest Namespace // usage: //ns := beego.NewNamespace(“/v1”). //Namespace( @@ -230,7 +231,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { return n } -// register Namespace into beego.Handler +// AddNamespace register Namespace into beego.Handler // support multi Namespace func AddNamespace(nl ...*Namespace) { for _, n := range nl { @@ -275,113 +276,113 @@ func addPrefix(t *Tree, prefix string) { } -// Namespace Condition -func NSCond(cond namespaceCond) innnerNamespace { +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { return func(ns *Namespace) { ns.Cond(cond) } } -// Namespace BeforeRouter filter -func NSBefore(filiterList ...FilterFunc) innnerNamespace { +// NSBefore Namespace BeforeRouter filter +func NSBefore(filiterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Filter("before", filiterList...) } } -// Namespace FinishRouter filter -func NSAfter(filiterList ...FilterFunc) innnerNamespace { +// NSAfter add Namespace FinishRouter filter +func NSAfter(filiterList ...FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Filter("after", filiterList...) } } -// Namespace Include ControllerInterface -func NSInclude(cList ...ControllerInterface) innnerNamespace { +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.Include(cList...) } } -// Namespace Router -func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace { +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { return func(ns *Namespace) { ns.Router(rootpath, c, mappingMethods...) } } -// Namespace Get -func NSGet(rootpath string, f FilterFunc) innnerNamespace { +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Get(rootpath, f) } } -// Namespace Post -func NSPost(rootpath string, f FilterFunc) innnerNamespace { +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Post(rootpath, f) } } -// Namespace Head -func NSHead(rootpath string, f FilterFunc) innnerNamespace { +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Head(rootpath, f) } } -// Namespace Put -func NSPut(rootpath string, f FilterFunc) innnerNamespace { +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Put(rootpath, f) } } -// Namespace Delete -func NSDelete(rootpath string, f FilterFunc) innnerNamespace { +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Delete(rootpath, f) } } -// Namespace Any -func NSAny(rootpath string, f FilterFunc) innnerNamespace { +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Any(rootpath, f) } } -// Namespace Options -func NSOptions(rootpath string, f FilterFunc) innnerNamespace { +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Options(rootpath, f) } } -// Namespace Patch -func NSPatch(rootpath string, f FilterFunc) innnerNamespace { +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { return func(ns *Namespace) { ns.Patch(rootpath, f) } } -//Namespace AutoRouter -func NSAutoRouter(c ControllerInterface) innnerNamespace { +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.AutoRouter(c) } } -// Namespace AutoPrefix -func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace { +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { return func(ns *Namespace) { ns.AutoPrefix(prefix, c) } } -// Namespace add sub Namespace -func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace { +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { return func(ns *Namespace) { n := NewNamespace(prefix, params...) ns.Namespace(n) diff --git a/orm/cmd.go b/orm/cmd.go index 2358ef3c..3638a75c 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -46,7 +46,7 @@ func printHelp(errs ...string) { os.Exit(2) } -// listen for orm command and then run it if command arguments passed. +// RunCommand listen for orm command and then run it if command arguments passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) { func (d *commandSyncDb) Run() error { var drops []string if d.force { - drops = getDbDropSql(d.al) + drops = getDbDropSQL(d.al) } db := d.al.DB @@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error { } } - sqls, indexes := getDbCreateSql(d.al) + sqls, indexes := getDbCreateSQL(d.al) tables, err := d.al.DbBaser.GetTables(db) if err != nil { @@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } - query := idx.Sql + query := idx.SQL _, err := db.Exec(query) if d.verbose { fmt.Printf(" %s\n", query) @@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error { queries := []string{sqls[i]} for _, idx := range indexes[mi.table] { - queries = append(queries, idx.Sql) + queries = append(queries, idx.SQL) } for _, query := range queries { @@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error { } // database creation commander interface implement. -type commandSqlAll struct { +type commandSQLAll struct { al *alias } // parse orm command line arguments. -func (d *commandSqlAll) Parse(args []string) { +func (d *commandSQLAll) Parse(args []string) { var name string flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) @@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) { } // run orm line command. -func (d *commandSqlAll) Run() error { - sqls, indexes := getDbCreateSql(d.al) +func (d *commandSQLAll) Run() error { + sqls, indexes := getDbCreateSQL(d.al) var all []string for i, mi := range modelCache.allOrdered() { queries := []string{sqls[i]} for _, idx := range indexes[mi.table] { - queries = append(queries, idx.Sql) + queries = append(queries, idx.SQL) } sql := strings.Join(queries, "\n") all = append(all, sql) @@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error { func init() { commands["syncdb"] = new(commandSyncDb) - commands["sqlall"] = new(commandSqlAll) + commands["sqlall"] = new(commandSQLAll) } -// run syncdb command line. +// RunSyncdb run syncdb command line. // name means table's alias name. default is "default". // force means run next sql if the current is error. // verbose means show all info when running command or not. diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index ea105624..da0ee8ab 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -23,11 +23,11 @@ import ( type dbIndex struct { Table string Name string - Sql string + SQL string } // create database drop sql. -func getDbDropSql(al *alias) (sqls []string) { +func getDbDropSQL(al *alias) (sqls []string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") os.Exit(2) @@ -45,13 +45,14 @@ func getDbDropSql(al *alias) (sqls []string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() fieldType := fi.fieldType + fieldSize := fi.size checkColumn: switch fieldType { case TypeBooleanField: col = T["bool"] case TypeCharField: - col = fmt.Sprintf(T["string"], fi.size) + col = fmt.Sprintf(T["string"], fieldSize) case TypeTextField: col = T["string-text"] case TypeDateField: @@ -65,7 +66,7 @@ checkColumn: case TypeIntegerField: col = T["int32"] case TypeBigIntegerField: - if al.Driver == DR_Sqlite { + if al.Driver == DRSqlite { fieldType = TypeIntegerField goto checkColumn } @@ -89,6 +90,7 @@ checkColumn: } case RelForeignKey, RelOneToOne: fieldType = fi.relModelInfo.fields.pk.fieldType + fieldSize = fi.relModelInfo.fields.pk.size goto checkColumn } @@ -104,15 +106,15 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { typ += " " + "NOT NULL" } - return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.mi.table, Q, + Q, fi.column, Q, typ, getColumnDefault(fi), ) } // create database creation string. -func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { +func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") os.Exit(2) @@ -142,7 +144,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if fi.auto { switch al.Driver { - case DR_Sqlite, DR_Postgres: + case DRSqlite, DRPostgres: column += T["auto"] default: column += col + " " + T["auto"] @@ -159,7 +161,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex //if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() //} - + // Append attribute DEFAULT column += getColumnDefault(fi) @@ -201,7 +203,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex sql += strings.Join(columns, ",\n") sql += "\n)" - if al.Driver == DR_MySQL { + if al.Driver == DRMySQL { var engine string if mi.model != nil { engine = getTableEngine(mi.addrField) @@ -237,7 +239,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex index := dbIndex{} index.Table = mi.table index.Name = name - index.Sql = sql + index.SQL = sql tableIndexes[mi.table] = append(tableIndexes[mi.table], index) } @@ -247,7 +249,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex return } - // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands func getColumnDefault(fi *fieldInfo) string { var ( @@ -263,16 +264,20 @@ func getColumnDefault(fi *fieldInfo) string { // These defaults will be useful if there no config value orm:"default" and NOT NULL is on switch fi.fieldType { - case TypeDateField, TypeDateTimeField: - return v; - - case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField, - TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + case TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, TypeDecimalField: - d = "0" + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" } - + if fi.colDefault { if !fi.initial.Exist() { v = fmt.Sprintf(t, "") diff --git a/orm/db.go b/orm/db.go index 20dc80f2..b62c165b 100644 --- a/orm/db.go +++ b/orm/db.go @@ -24,12 +24,13 @@ import ( ) const ( - format_Date = "2006-01-02" - format_DateTime = "2006-01-02 15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" ) var ( - ErrMissPK = errors.New("missed pk value") // missing pk error + // ErrMissPK missing pk error + ErrMissPK = errors.New("missed pk value") ) var ( @@ -216,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } } if fi.null == false && value == nil { - return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) } } } } switch fi.fieldType { case TypeDateField, TypeDateTimeField: - if fi.auto_now || fi.auto_now_add && insert { + if fi.autoNow || fi.autoNowAdd && insert { if insert { if t, ok := value.(time.Time); ok && !t.IsZero() { break @@ -282,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, var id int64 err := row.Scan(&id) return id, err - } else { - if res, err := stmt.Exec(values...); err == nil { - return res.LastInsertId() - } else { - return 0, err - } } + res, err := stmt.Exec(values...) + if err == nil { + return res.LastInsertId() + } + return 0, err } // query sql ,read records and persist in dbBaser. @@ -339,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo return ErrNoRows } return err - } else { - elm := reflect.New(mi.addrField.Elem().Type()) - mind := reflect.Indirect(elm) - - d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) - - ind.Set(mind) } - + elm := reflect.New(mi.addrField.Elem().Type()) + mind := reflect.Indirect(elm) + d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + ind.Set(mind) return nil } @@ -444,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - if res, err := q.Exec(query, values...); err == nil { + res, err := q.Exec(query, values...) + if err == nil { if isMulti { return res.RowsAffected() } return res.LastInsertId() - } else { - return 0, err } - } else { - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err + return 0, err } + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err } // execute update sql dbQuerier with given struct reflect.Value. @@ -493,11 +488,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - if res, err := q.Exec(query, setValues...); err == nil { + res, err := q.Exec(query, setValues...) + if err == nil { return res.RowsAffected() - } else { - return 0, err } + return 0, err } // execute delete sql dbQuerier with given struct reflect.Value. @@ -513,14 +508,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q) d.ins.ReplaceMarks(&query) - - if res, err := q.Exec(query, pkValue); err == nil { - + res, err := q.Exec(query, pkValue) + if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } - if num > 0 { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { @@ -529,17 +522,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.Field(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, []interface{}{pkValue}, tz) if err != nil { return num, err } } - return num, err - } else { - return 0, err } + return 0, err } // update table-related record by querySet. @@ -565,11 +555,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con tables.parseRelated(qs.related, qs.relDepth) } - where, args := tables.getCondSql(cond, false, tz) + where, args := tables.getCondSQL(cond, false, tz) values = append(values, args...) - join := tables.getJoinSql() + join := tables.getJoinSQL() var query, T string @@ -585,13 +575,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) if c, ok := values[i].(colValue); ok { switch c.opt { - case Col_Add: + case ColAdd: cols = append(cols, col+" = "+col+" + ?") - case Col_Minus: + case ColMinus: cols = append(cols, col+" = "+col+" - ?") - case Col_Multiply: + case ColMultiply: cols = append(cols, col+" = "+col+" * ?") - case Col_Except: + case ColExcept: cols = append(cols, col+" = "+col+" / ?") } values[i] = c.value @@ -610,12 +600,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - - if res, err := q.Exec(query, values...); err == nil { + res, err := q.Exec(query, values...) + if err == nil { return res.RowsAffected() - } else { - return 0, err } + return 0, err } // delete related records. @@ -624,23 +613,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { - case od_CASCADE: + case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) if err != nil { return err } - case od_SET_DEFAULT, od_SET_NULL: + case odSetDefault, odSetNULL: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) params := Params{fi.column: nil} - if fi.onDelete == od_SET_DEFAULT { + if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) if err != nil { return err } - case od_DO_NOTHING: + case odDoNothing: } } return nil @@ -661,8 +650,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con Q := d.ins.TableQuote() - where, args := tables.getCondSql(cond, false, tz) - join := tables.getJoinSql() + where, args := tables.getCondSQL(cond, false, tz) + join := tables.getJoinSQL() cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) @@ -670,16 +659,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - if r, err := q.Query(query, args...); err != nil { + r, err := q.Query(query, args...) + if err != nil { return 0, err - } else { - rs = r } - + rs = r defer rs.Close() var ref interface{} - args = make([]interface{}, 0) cnt := 0 for rs.Next() { @@ -702,24 +689,21 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) d.ins.ReplaceMarks(&query) - - if res, err := q.Exec(query, args...); err == nil { + res, err := q.Exec(query, args...) + if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } - if num > 0 { err := d.deleteRels(q, mi, args, tz) if err != nil { return num, err } } - return num, nil - } else { - return 0, err } + return 0, err } // read related records. @@ -801,10 +785,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) - where, args := tables.getCondSql(cond, false, tz) - orderBy := tables.getOrderSql(qs.orders) - limit := tables.getLimitSql(mi, offset, rlimit) - join := tables.getJoinSql() + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, offset, rlimit) + join := tables.getJoinSQL() for _, tbl := range tables.tables { if tbl.sel { @@ -814,16 +799,20 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } } - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) + sqlSelect := "SELECT" + if qs.distinct { + sqlSelect += " DISTINCT" + } + query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) var rs *sql.Rows - if r, err := q.Query(query, args...); err != nil { + r, err := q.Query(query, args...) + if err != nil { return 0, err - } else { - rs = r } + rs = r refs := make([]interface{}, colsNum) for i := range refs { @@ -937,9 +926,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) - where, args := tables.getCondSql(cond, false, tz) - tables.getOrderSql(qs.orders) - join := tables.getJoinSql() + where, args := tables.getCondSQL(cond, false, tz) + tables.getOrderSQL(qs.orders) + join := tables.getJoinSQL() Q := d.ins.TableQuote() @@ -954,7 +943,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition } // generate sql with replacing operator string placeholders and replaced values. -func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { +func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { sql := "" params := getFlatParams(fi, args, tz) @@ -979,7 +968,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri if len(params) > 1 { panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) } - sql = d.ins.OperatorSql(operator) + sql = d.ins.OperatorSQL(operator) switch operator { case "exact": if arg == nil { @@ -1107,12 +1096,12 @@ setValue: ) if len(s) >= 19 { s = s[:19] - t, err = time.ParseInLocation(format_DateTime, s, tz) + t, err = time.ParseInLocation(formatDateTime, s, tz) } else { if len(s) > 10 { s = s[:10] } - t, err = time.ParseInLocation(format_Date, s, tz) + t, err = time.ParseInLocation(formatDate, s, tz) } t = t.In(DefaultTimeLoc) @@ -1443,24 +1432,22 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond } } - where, args := tables.getCondSql(cond, false, tz) - orderBy := tables.getOrderSql(qs.orders) - limit := tables.getLimitSql(mi, qs.offset, qs.limit) - join := tables.getJoinSql() + where, args := tables.getCondSQL(cond, false, tz) + groupBy := tables.getGroupSQL(qs.groups) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, qs.offset, qs.limit) + join := tables.getJoinSQL() sels := strings.Join(cols, ", ") - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) d.ins.ReplaceMarks(&query) - var rs *sql.Rows - if r, err := q.Query(query, args...); err != nil { + rs, err := q.Query(query, args...) + if err != nil { return 0, err - } else { - rs = r } - refs := make([]interface{}, len(cols)) for i := range refs { var ref interface{} @@ -1475,11 +1462,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond ) for rs.Next() { if cnt == 0 { - if cols, err := rs.Columns(); err != nil { + cols, err := rs.Columns() + if err != nil { return 0, err - } else { - columns = cols } + columns = cols } if err := rs.Scan(refs...); err != nil { @@ -1643,7 +1630,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e } // not implement. -func (d *dbBase) OperatorSql(operator string) string { +func (d *dbBase) OperatorSQL(operator string) string { panic(ErrNotImplement) } diff --git a/orm/db_alias.go b/orm/db_alias.go index 0a862241..79576b8e 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -22,15 +22,17 @@ import ( "time" ) -// database driver constant int. +// DriverType database driver constant int. type DriverType int +// Enum the Database driver const ( - _ DriverType = iota // int enum type - DR_MySQL // mysql - DR_Sqlite // sqlite - DR_Oracle // oracle - DR_Postgres // pgsql + _ DriverType = iota // int enum type + DRMySQL // mysql + DRSqlite // sqlite + DROracle // oracle + DRPostgres // pgsql + DRTiDB // TiDB ) // database driver string. @@ -53,15 +55,17 @@ var _ Driver = new(driver) var ( dataBaseCache = &_dbCache{cache: make(map[string]*alias)} drivers = map[string]DriverType{ - "mysql": DR_MySQL, - "postgres": DR_Postgres, - "sqlite3": DR_Sqlite, + "mysql": DRMySQL, + "postgres": DRPostgres, + "sqlite3": DRSqlite, + "tidb": DRTiDB, } dbBasers = map[DriverType]dbBaser{ - DR_MySQL: newdbBaseMysql(), - DR_Sqlite: newdbBaseSqlite(), - DR_Oracle: newdbBaseMysql(), - DR_Postgres: newdbBasePostgres(), + DRMySQL: newdbBaseMysql(), + DRSqlite: newdbBaseSqlite(), + DROracle: newdbBaseOracle(), + DRPostgres: newdbBasePostgres(), + DRTiDB: newdbBaseTidb(), } ) @@ -119,7 +123,7 @@ func detectTZ(al *alias) { } switch al.Driver { - case DR_MySQL: + case DRMySQL: row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") var tz string row.Scan(&tz) @@ -147,10 +151,10 @@ func detectTZ(al *alias) { al.Engine = "INNODB" } - case DR_Sqlite: + case DRSqlite: al.TZ = time.UTC - case DR_Postgres: + case DRPostgres: row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") var tz string row.Scan(&tz) @@ -188,12 +192,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { return al, nil } +// AddAliasWthDB add a aliasName for the drivename func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { _, err := addAliasWthDB(aliasName, driverName, db) return err } -// Setting the database connect params. Use the database driver self dataSource args. +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { var ( err error @@ -236,7 +241,7 @@ end: return err } -// Register a database driver use specify driver name, this can be definition the driver is which database type. +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. func RegisterDriver(driverName string, typ DriverType) error { if t, ok := drivers[driverName]; ok == false { drivers[driverName] = typ @@ -248,7 +253,7 @@ func RegisterDriver(driverName string, typ DriverType) error { return nil } -// Change the database default used timezone +// SetDataBaseTZ Change the database default used timezone func SetDataBaseTZ(aliasName string, tz *time.Location) error { if al, ok := dataBaseCache.get(aliasName); ok { al.TZ = tz @@ -258,14 +263,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { return nil } -// Change the max idle conns for *sql.DB, use specify database alias name +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name func SetMaxIdleConns(aliasName string, maxIdleConns int) { al := getDbAlias(aliasName) al.MaxIdleConns = maxIdleConns al.DB.SetMaxIdleConns(maxIdleConns) } -// Change the max open conns for *sql.DB, use specify database alias name +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) al.MaxOpenConns = maxOpenConns @@ -275,7 +280,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) { } } -// Get *sql.DB from registered database by db alias name. +// GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. func GetDB(aliasNames ...string) (*sql.DB, error) { var name string @@ -284,9 +289,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } else { name = "default" } - if al, ok := dataBaseCache.get(name); ok { + al, ok := dataBaseCache.get(name) + if ok { return al.DB, nil - } else { - return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name) } + return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name) } diff --git a/orm/db_mysql.go b/orm/db_mysql.go index 182914a2..10fe2657 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -67,7 +67,7 @@ type dbBaseMysql struct { var _ dbBaser = new(dbBaseMysql) // get mysql operator. -func (d *dbBaseMysql) OperatorSql(operator string) string { +func (d *dbBaseMysql) OperatorSQL(operator string) string { return mysqlOperators[operator] } diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 6500ef52..7dbef95a 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -66,7 +66,7 @@ type dbBasePostgres struct { var _ dbBaser = new(dbBasePostgres) // get postgresql operator. -func (d *dbBasePostgres) OperatorSql(operator string) string { +func (d *dbBasePostgres) OperatorSQL(operator string) string { return postgresOperators[operator] } @@ -101,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { num := 0 for _, c := range q { if c == '?' { - num += 1 + num++ } } if num == 0 { @@ -114,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { if c == '?' { data = append(data, '$') data = append(data, []byte(strconv.Itoa(num))...) - num += 1 + num++ } else { data = append(data, c) } diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index c2dcbb46..a3cb69a7 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -66,7 +66,7 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // get sqlite operator. -func (d *dbBaseSqlite) OperatorSql(operator string) string { +func (d *dbBaseSqlite) OperatorSQL(operator string) string { return sqliteOperators[operator] } diff --git a/orm/db_tables.go b/orm/db_tables.go index a9aa10ab..e4c74ace 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { } // generate join string. -func (t *dbTables) getJoinSql() (join string) { +func (t *dbTables) getJoinSQL() (join string) { Q := t.base.TableQuote() for _, jt := range t.tables { @@ -186,7 +186,7 @@ func (t *dbTables) getJoinSql() (join string) { table = jt.mi.table switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: + case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: c1 = jt.fi.mi.fields.pk.column for _, ffi := range jt.mi.fields.fieldsRel { if jt.fi.mi == ffi.relModelInfo { @@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string ) num := len(exprs) - 1 - names := make([]string, 0) + var names []string inner := true @@ -326,7 +326,7 @@ loopFor: } // generate condition sql. -func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { +func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { if cond == nil || cond.IsEmpty() { return } @@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe where += "NOT " } if p.isCond { - w, ps := t.getCondSql(p.cond, true, tz) + w, ps := t.getCondSQL(p.cond, true, tz) if w != "" { w = fmt.Sprintf("( %s) ", w) } @@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe operator = "exact" } - operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz) + operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) - where += fmt.Sprintf("%s %s ", leftCol, operSql) + where += fmt.Sprintf("%s %s ", leftCol, operSQL) params = append(params, args...) } @@ -390,8 +390,32 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe return } +// generate group sql. +func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { + if len(groups) == 0 { + return + } + + Q := t.base.TableQuote() + + groupSqls := make([]string, 0, len(groups)) + for _, group := range groups { + exprs := strings.Split(group, ExprSep) + + index, _, fi, suc := t.parseExprs(t.mi, exprs) + if suc == false { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + } + + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + } + + groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) + return +} + // generate order sql. -func (t *dbTables) getOrderSql(orders []string) (orderSql string) { +func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { if len(orders) == 0 { return } @@ -415,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) { orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) } - orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) + orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) return } // generate limit sql. -func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { +func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { if limit == 0 { limit = int64(DefaultRowsLimit) } diff --git a/orm/db_tidb.go b/orm/db_tidb.go new file mode 100644 index 00000000..6020a488 --- /dev/null +++ b/orm/db_tidb.go @@ -0,0 +1,63 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" +) + +// mysql dbBaser implementation. +type dbBaseTidb struct { + dbBase +} + +var _ dbBaser = new(dbBaseTidb) + +// get mysql operator. +func (d *dbBaseTidb) OperatorSQL(operator string) string { + return mysqlOperators[operator] +} + +// get mysql table field types. +func (d *dbBaseTidb) DbTypes() map[string]string { + return mysqlTypes +} + +// show table sql for mysql. +func (d *dbBaseTidb) ShowTablesQuery() string { + return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" +} + +// show columns sql of table for mysql. +func (d *dbBaseTidb) ShowColumnsQuery(table string) string { + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + "WHERE table_schema = DATABASE() AND table_name = '%s'", table) +} + +// execute sql to check index exist. +func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { + row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) + var cnt int + row.Scan(&cnt) + return cnt > 0 +} + +// create new mysql dbBaser. +func newdbBaseTidb() dbBaser { + b := new(dbBaseTidb) + b.ins = b + return b +} diff --git a/orm/db_utils.go b/orm/db_utils.go index 4a3ba464..ae9b1625 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -24,9 +24,8 @@ import ( func getDbAlias(name string) *alias { if al, ok := dataBaseCache.get(name); ok { return al - } else { - panic(fmt.Errorf("unknown DataBase alias name %s", name)) } + panic(fmt.Errorf("unknown DataBase alias name %s", name)) } // get pk column info. @@ -80,19 +79,19 @@ outFor: var err error if len(v) >= 19 { s := v[:19] - t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc) + t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) } else { s := v if len(v) > 10 { s = v[:10] } - t, err = time.ParseInLocation(format_Date, s, tz) + t, err = time.ParseInLocation(formatDate, s, tz) } if err == nil { if fi.fieldType == TypeDateField { - v = t.In(tz).Format(format_Date) + v = t.In(tz).Format(formatDate) } else { - v = t.In(tz).Format(format_DateTime) + v = t.In(tz).Format(formatDateTime) } } } @@ -137,9 +136,9 @@ outFor: case reflect.Struct: if v, ok := arg.(time.Time); ok { if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(format_Date) + arg = v.In(tz).Format(formatDate) } else { - arg = v.In(tz).Format(format_DateTime) + arg = v.In(tz).Format(formatDateTime) } } else { typ := val.Type() diff --git a/orm/models.go b/orm/models.go index dcb32b55..faf551be 100644 --- a/orm/models.go +++ b/orm/models.go @@ -19,10 +19,10 @@ import ( ) const ( - od_CASCADE = "cascade" - od_SET_NULL = "set_null" - od_SET_DEFAULT = "set_default" - od_DO_NOTHING = "do_nothing" + odCascade = "cascade" + odSetNULL = "set_null" + odSetDefault = "set_default" + odDoNothing = "do_nothing" defaultStructTagName = "orm" defaultStructTagDelim = ";" ) @@ -113,7 +113,7 @@ func (mc *_modelCache) clean() { mc.done = false } -// Clean model cache. Then you can re-RegisterModel. +// ResetModelCache Clean model cache. Then you can re-RegisterModel. // Common use this api for test case. func ResetModelCache() { modelCache.clean() diff --git a/orm/models_boot.go b/orm/models_boot.go index cb44bc05..3690557b 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -51,19 +51,16 @@ func registerModel(prefix string, model interface{}) { } info := newModelInfo(val) - if info.fields.pk == nil { outFor: for _, fi := range info.fields.fieldsDB { - if fi.name == "Id" { - if fi.sf.Tag.Get(defaultStructTagName) == "" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - info.fields.pk = fi - break outFor - } + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + info.fields.pk = fi + break outFor } } } @@ -269,7 +266,10 @@ func bootStrap() { if found == false { mForC: for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - if ffi.relModelInfo == mi { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { found = true fi.reverseField = ffi.reverseFieldInfoTwo.name @@ -298,12 +298,12 @@ end: } } -// register models +// RegisterModel register models func RegisterModel(models ...interface{}) { RegisterModelWithPrefix("", models...) } -// register models with a prefix +// RegisterModelWithPrefix register models with a prefix func RegisterModelWithPrefix(prefix string, models ...interface{}) { if modelCache.done { panic(fmt.Errorf("RegisterModel must be run before BootStrap")) @@ -314,7 +314,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) { } } -// bootrap models. +// BootStrap bootrap models. // make all model parsed and can not add more models func BootStrap() { if modelCache.done { diff --git a/orm/models_fields.go b/orm/models_fields.go index f038dd0f..a8cf8e4f 100644 --- a/orm/models_fields.go +++ b/orm/models_fields.go @@ -15,49 +15,28 @@ package orm import ( - "errors" "fmt" "strconv" "time" ) +// Define the Type enum const ( - // bool TypeBooleanField = 1 << iota - - // string TypeCharField - - // string TypeTextField - - // time.Time TypeDateField - // time.Time TypeDateTimeField - - // int8 TypeBitField - // int16 TypeSmallIntegerField - // int32 TypeIntegerField - // int64 TypeBigIntegerField - // uint8 TypePositiveBitField - // uint16 TypePositiveSmallIntegerField - // uint32 TypePositiveIntegerField - // uint64 TypePositiveBigIntegerField - - // float64 TypeFloatField - // float64 TypeDecimalField - RelForeignKey RelOneToOne RelManyToMany @@ -65,6 +44,7 @@ const ( RelReverseMany ) +// Define some logic enum const ( IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 @@ -72,25 +52,30 @@ const ( IsFieldType = ^-RelReverseMany<<1 + 1 ) -// A true/false field. +// BooleanField A true/false field. type BooleanField bool +// Value return the BooleanField func (e BooleanField) Value() bool { return bool(e) } +// Set will set the BooleanField func (e *BooleanField) Set(d bool) { *e = BooleanField(d) } +// String format the Bool to string func (e *BooleanField) String() string { return strconv.FormatBool(e.Value()) } +// FieldType return BooleanField the type func (e *BooleanField) FieldType() int { return TypeBooleanField } +// SetRaw set the interface to bool func (e *BooleanField) SetRaw(value interface{}) error { switch d := value.(type) { case bool: @@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error { } return err default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return the current value func (e *BooleanField) RawValue() interface{} { return e.Value() } +// verify the BooleanField implement the Fielder interface var _ Fielder = new(BooleanField) -// A string field +// CharField A string field // required values tag: size // The size is enforced at the database level and in models’s validation. // eg: `orm:"size(120)"` type CharField string +// Value return the CharField's Value func (e CharField) Value() string { return string(e) } +// Set CharField value func (e *CharField) Set(d string) { *e = CharField(d) } +// String return the CharField func (e *CharField) String() string { return e.Value() } +// FieldType return the enum type func (e *CharField) FieldType() int { return TypeCharField } +// SetRaw set the interface to string func (e *CharField) SetRaw(value interface{}) error { switch d := value.(type) { case string: e.Set(d) default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return the CharField value func (e *CharField) RawValue() interface{} { return e.Value() } +// verify CharField implement Fielder var _ Fielder = new(CharField) -// A date, represented in go by a time.Time instance. +// DateField A date, represented in go by a time.Time instance. // only date values like 2006-01-02 // Has a few extra, optional attr tag: // @@ -166,106 +160,125 @@ var _ Fielder = new(CharField) // eg: `orm:"auto_now"` or `orm:"auto_now_add"` type DateField time.Time +// Value return the time.Time func (e DateField) Value() time.Time { return time.Time(e) } +// Set set the DateField's value func (e *DateField) Set(d time.Time) { *e = DateField(d) } +// String convert datatime to string func (e *DateField) String() string { return e.Value().String() } +// FieldType return enum type Date func (e *DateField) FieldType() int { return TypeDateField } +// SetRaw convert the interface to time.Time. Allow string and time.Time func (e *DateField) SetRaw(value interface{}) error { switch d := value.(type) { case time.Time: e.Set(d) case string: - v, err := timeParse(d, format_Date) + v, err := timeParse(d, formatDate) if err != nil { e.Set(v) } return err default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return Date value func (e *DateField) RawValue() interface{} { return e.Value() } +// verify DateField implement fielder interface var _ Fielder = new(DateField) -// A date, represented in go by a time.Time instance. +// DateTimeField A date, represented in go by a time.Time instance. // datetime values like 2006-01-02 15:04:05 // Takes the same extra arguments as DateField. type DateTimeField time.Time +// Value return the datatime value func (e DateTimeField) Value() time.Time { return time.Time(e) } +// Set set the time.Time to datatime func (e *DateTimeField) Set(d time.Time) { *e = DateTimeField(d) } +// String return the time's String func (e *DateTimeField) String() string { return e.Value().String() } +// FieldType return the enum TypeDateTimeField func (e *DateTimeField) FieldType() int { return TypeDateTimeField } +// SetRaw convert the string or time.Time to DateTimeField func (e *DateTimeField) SetRaw(value interface{}) error { switch d := value.(type) { case time.Time: e.Set(d) case string: - v, err := timeParse(d, format_DateTime) + v, err := timeParse(d, formatDateTime) if err != nil { e.Set(v) } return err default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return the datatime value func (e *DateTimeField) RawValue() interface{} { return e.Value() } +// verify datatime implement fielder var _ Fielder = new(DateTimeField) -// A floating-point number represented in go by a float32 value. +// FloatField A floating-point number represented in go by a float32 value. type FloatField float64 +// Value return the FloatField value func (e FloatField) Value() float64 { return float64(e) } +// Set the Float64 func (e *FloatField) Set(d float64) { *e = FloatField(d) } +// String return the string func (e *FloatField) String() string { return ToStr(e.Value(), -1, 32) } +// FieldType return the enum type func (e *FloatField) FieldType() int { return TypeFloatField } +// SetRaw converter interface Float64 float32 or string to FloatField func (e *FloatField) SetRaw(value interface{}) error { switch d := value.(type) { case float32: @@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return the FloatField value func (e *FloatField) RawValue() interface{} { return e.Value() } +// verify FloatField implement Fielder var _ Fielder = new(FloatField) -// -32768 to 32767 +// SmallIntegerField -32768 to 32767 type SmallIntegerField int16 +// Value return int16 value func (e SmallIntegerField) Value() int16 { return int16(e) } +// Set the SmallIntegerField value func (e *SmallIntegerField) Set(d int16) { *e = SmallIntegerField(d) } +// String convert smallint to string func (e *SmallIntegerField) String() string { return ToStr(e.Value()) } +// FieldType return enum type SmallIntegerField func (e *SmallIntegerField) FieldType() int { return TypeSmallIntegerField } +// SetRaw convert interface int16/string to int16 func (e *SmallIntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case int16: @@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return smallint value func (e *SmallIntegerField) RawValue() interface{} { return e.Value() } +// verify SmallIntegerField implement Fielder var _ Fielder = new(SmallIntegerField) -// -2147483648 to 2147483647 +// IntegerField -2147483648 to 2147483647 type IntegerField int32 +// Value return the int32 func (e IntegerField) Value() int32 { return int32(e) } +// Set IntegerField value func (e *IntegerField) Set(d int32) { *e = IntegerField(d) } +// String convert Int32 to string func (e *IntegerField) String() string { return ToStr(e.Value()) } +// FieldType return the enum type func (e *IntegerField) FieldType() int { return TypeIntegerField } +// SetRaw convert interface int32/string to int32 func (e *IntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case int32: @@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return IntegerField value func (e *IntegerField) RawValue() interface{} { return e.Value() } +// verify IntegerField implement Fielder var _ Fielder = new(IntegerField) -// -9223372036854775808 to 9223372036854775807. +// BigIntegerField -9223372036854775808 to 9223372036854775807. type BigIntegerField int64 +// Value return int64 func (e BigIntegerField) Value() int64 { return int64(e) } +// Set the BigIntegerField value func (e *BigIntegerField) Set(d int64) { *e = BigIntegerField(d) } +// String convert BigIntegerField to string func (e *BigIntegerField) String() string { return ToStr(e.Value()) } +// FieldType return enum type func (e *BigIntegerField) FieldType() int { return TypeBigIntegerField } +// SetRaw convert interface int64/string to int64 func (e *BigIntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case int64: @@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return BigIntegerField value func (e *BigIntegerField) RawValue() interface{} { return e.Value() } +// verify BigIntegerField implement Fielder var _ Fielder = new(BigIntegerField) -// 0 to 65535 +// PositiveSmallIntegerField 0 to 65535 type PositiveSmallIntegerField uint16 +// Value return uint16 func (e PositiveSmallIntegerField) Value() uint16 { return uint16(e) } +// Set PositiveSmallIntegerField value func (e *PositiveSmallIntegerField) Set(d uint16) { *e = PositiveSmallIntegerField(d) } +// String convert uint16 to string func (e *PositiveSmallIntegerField) String() string { return ToStr(e.Value()) } +// FieldType return enum type func (e *PositiveSmallIntegerField) FieldType() int { return TypePositiveSmallIntegerField } +// SetRaw convert Interface uint16/string to uint16 func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case uint16: @@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue returns PositiveSmallIntegerField value func (e *PositiveSmallIntegerField) RawValue() interface{} { return e.Value() } +// verify PositiveSmallIntegerField implement Fielder var _ Fielder = new(PositiveSmallIntegerField) -// 0 to 4294967295 +// PositiveIntegerField 0 to 4294967295 type PositiveIntegerField uint32 +// Value return PositiveIntegerField value. Uint32 func (e PositiveIntegerField) Value() uint32 { return uint32(e) } +// Set the PositiveIntegerField value func (e *PositiveIntegerField) Set(d uint32) { *e = PositiveIntegerField(d) } +// String convert PositiveIntegerField to string func (e *PositiveIntegerField) String() string { return ToStr(e.Value()) } +// FieldType return enum type func (e *PositiveIntegerField) FieldType() int { return TypePositiveIntegerField } +// SetRaw convert interface uint32/string to Uint32 func (e *PositiveIntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case uint32: @@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return the PositiveIntegerField Value func (e *PositiveIntegerField) RawValue() interface{} { return e.Value() } +// verify PositiveIntegerField implement Fielder var _ Fielder = new(PositiveIntegerField) -// 0 to 18446744073709551615 +// PositiveBigIntegerField 0 to 18446744073709551615 type PositiveBigIntegerField uint64 +// Value return uint64 func (e PositiveBigIntegerField) Value() uint64 { return uint64(e) } +// Set PositiveBigIntegerField value func (e *PositiveBigIntegerField) Set(d uint64) { *e = PositiveBigIntegerField(d) } +// String convert PositiveBigIntegerField to string func (e *PositiveBigIntegerField) String() string { return ToStr(e.Value()) } +// FieldType return enum type func (e *PositiveBigIntegerField) FieldType() int { return TypePositiveIntegerField } +// SetRaw convert interface uint64/string to Uint64 func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { switch d := value.(type) { case uint64: @@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { e.Set(v) } default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return PositiveBigIntegerField value func (e *PositiveBigIntegerField) RawValue() interface{} { return e.Value() } +// verify PositiveBigIntegerField implement Fielder var _ Fielder = new(PositiveBigIntegerField) -// A large text field. +// TextField A large text field. type TextField string +// Value return TextField value func (e TextField) Value() string { return string(e) } +// Set the TextField value func (e *TextField) Set(d string) { *e = TextField(d) } +// String convert TextField to string func (e *TextField) String() string { return e.Value() } +// FieldType return enum type func (e *TextField) FieldType() int { return TypeTextField } +// SetRaw convert interface string to string func (e *TextField) SetRaw(value interface{}) error { switch d := value.(type) { case string: e.Set(d) default: - return errors.New(fmt.Sprintf(" unknown value `%s`", value)) + return fmt.Errorf(" unknown value `%s`", value) } return nil } +// RawValue return TextField value func (e *TextField) RawValue() interface{} { return e.Value() } +// verify TextField implement Fielder var _ Fielder = new(TextField) diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 84a0c024..14e1f2c6 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -119,8 +119,8 @@ type fieldInfo struct { colDefault bool initial StrTo size int - auto_now bool - auto_now_add bool + autoNow bool + autoNowAdd bool rel bool reverse bool reverseField string @@ -223,6 +223,11 @@ checkType: break checkType case "many": fieldType = RelReverseMany + if tv := tags["rel_table"]; tv != "" { + fi.relTable = tv + } else if tv := tags["rel_through"]; tv != "" { + fi.relThrough = tv + } break checkType default: err = fmt.Errorf("error") @@ -309,20 +314,20 @@ checkType: if fi.rel && fi.dbcol { switch onDelete { - case od_CASCADE, od_DO_NOTHING: - case od_SET_DEFAULT: + case odCascade, odDoNothing: + case odSetDefault: if initial.Exist() == false { err = errors.New("on_delete: set_default need set field a default value") goto end } - case od_SET_NULL: + case odSetNULL: if fi.null == false { err = errors.New("on_delete: set_null need set field null") goto end } default: if onDelete == "" { - onDelete = od_CASCADE + onDelete = odCascade } else { err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) goto end @@ -350,9 +355,9 @@ checkType: fi.unique = false case TypeDateField, TypeDateTimeField: if attrs["auto_now"] { - fi.auto_now = true + fi.autoNow = true } else if attrs["auto_now_add"] { - fi.auto_now_add = true + fi.autoNowAdd = true } case TypeFloatField: case TypeDecimalField: diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 3600ee7c..2654cdb5 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -15,7 +15,6 @@ package orm import ( - "errors" "fmt" "os" "reflect" @@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { added := info.fields.Add(fi) if added == false { - err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) + err = fmt.Errorf("duplicate column name: %s", fi.column) break } if fi.pk { if info.fields.pk != nil { - err = errors.New(fmt.Sprintf("one model must have one pk field only")) + err = fmt.Errorf("one model must have one pk field only") break } else { info.fields.pk = fi diff --git a/orm/models_test.go b/orm/models_test.go index 6ca9590c..ee56e8e8 100644 --- a/orm/models_test.go +++ b/orm/models_test.go @@ -25,6 +25,9 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" ) // A slice string field. @@ -76,21 +79,21 @@ func (e *SliceStringField) RawValue() interface{} { var _ Fielder = new(SliceStringField) // A json field. -type JsonField struct { +type JSONField struct { Name string Data string } -func (e *JsonField) String() string { +func (e *JSONField) String() string { data, _ := json.Marshal(e) return string(data) } -func (e *JsonField) FieldType() int { +func (e *JSONField) FieldType() int { return TypeTextField } -func (e *JsonField) SetRaw(value interface{}) error { +func (e *JSONField) SetRaw(value interface{}) error { switch d := value.(type) { case string: return json.Unmarshal([]byte(d), e) @@ -99,14 +102,14 @@ func (e *JsonField) SetRaw(value interface{}) error { } } -func (e *JsonField) RawValue() interface{} { +func (e *JSONField) RawValue() interface{} { return e.String() } -var _ Fielder = new(JsonField) +var _ Fielder = new(JSONField) type Data struct { - Id int + ID int `orm:"column(id)"` Boolean bool Char string `orm:"size(50)"` Text string `orm:"type(text)"` @@ -130,7 +133,7 @@ type Data struct { } type DataNull struct { - Id int + ID int `orm:"column(id)"` Boolean bool `orm:"null"` Char string `orm:"null;size(50)"` Text string `orm:"null;type(text)"` @@ -193,7 +196,7 @@ type Float32 float64 type Float64 float64 type DataCustom struct { - Id int + ID int `orm:"column(id)"` Boolean Boolean Char string `orm:"size(50)"` Text string `orm:"type(text)"` @@ -216,28 +219,28 @@ type DataCustom struct { // only for mysql type UserBig struct { - Id uint64 + ID uint64 `orm:"column(id)"` Name string } type User struct { - Id int - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(1)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JsonField `orm:"type(text)"` - unexport bool `orm:"-"` - unexport_ bool + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONField `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool } func (u *User) TableIndex() [][]string { @@ -259,7 +262,7 @@ func NewUser() *User { } type Profile struct { - Id int + ID int `orm:"column(id)"` Age int16 Money float64 User *User `orm:"reverse(one)" json:"-"` @@ -276,7 +279,7 @@ func NewProfile() *Profile { } type Post struct { - Id int + ID int `orm:"column(id)"` User *User `orm:"rel(fk)"` Title string `orm:"size(60)"` Content string `orm:"type(text)"` @@ -297,7 +300,7 @@ func NewPost() *Post { } type Tag struct { - Id int + ID int `orm:"column(id)"` Name string `orm:"size(30)"` BestPost *Post `orm:"rel(one);null"` Posts []*Post `orm:"reverse(many)" json:"-"` @@ -309,7 +312,7 @@ func NewTag() *Tag { } type PostTags struct { - Id int + ID int `orm:"column(id)"` Post *Post `orm:"rel(fk)"` Tag *Tag `orm:"rel(fk)"` } @@ -319,7 +322,7 @@ func (m *PostTags) TableName() string { } type Comment struct { - Id int + ID int `orm:"column(id)"` Post *Post `orm:"rel(fk);column(post)"` Content string `orm:"type(text)"` Parent *Comment `orm:"null;rel(fk)"` @@ -331,6 +334,24 @@ func NewComment() *Comment { return obj } +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + var DBARGS = struct { Driver string Source string @@ -345,6 +366,7 @@ var ( IsMysql = DBARGS.Driver == "mysql" IsSqlite = DBARGS.Driver == "sqlite3" IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" ) var ( @@ -364,6 +386,7 @@ Default DB Drivers. mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq +tidb: https://github.com/pingcap/tidb usage: @@ -371,6 +394,7 @@ go get -u github.com/astaxie/beego/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq +go get -u github.com/pingcap/tidb #### MySQL mysql -u root -e 'create database orm_test;' @@ -390,6 +414,12 @@ psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/astaxie/beego/orm + +#### TiDB +export ORM_DRIVER=tidb +export ORM_SOURCE='memory://test/test' +go test -v github.com/astaxie/beego/orm + `) os.Exit(2) } @@ -397,7 +427,7 @@ go test -v github.com/astaxie/beego/orm RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) alias := getDbAlias("default") - if alias.Driver == DR_MySQL { + if alias.Driver == DRMySQL { alias.Engine = "INNODB" } diff --git a/orm/orm.go b/orm/orm.go index f881433b..d00d6d03 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package orm provide ORM for MySQL/PostgreSQL/sqlite // Simple Usage // // package main @@ -59,12 +60,13 @@ import ( "time" ) +// DebugQueries define the debug const ( - Debug_Queries = iota + DebugQueries = iota ) +// Define common vars var ( - // DebugLevel = Debug_Queries Debug = false DebugLog = NewLog(os.Stderr) DefaultRowsLimit = 1000 @@ -79,7 +81,10 @@ var ( ErrNotImplement = errors.New("have not implement") ) +// Params stores the Params type Params map[string]interface{} + +// ParamsList stores paramslist type ParamsList []interface{} type orm struct { @@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { o.setPk(mi, ind, id) - cnt += 1 + cnt++ } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) @@ -489,7 +494,7 @@ func (o *orm) Driver() Driver { return driver(o.alias.Name) } -// create new orm +// NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once @@ -501,7 +506,7 @@ func NewOrm() Ormer { return o } -// create a new ormer object with specify *sql.DB for query +// NewOrmWithDB create a new ormer object with specify *sql.DB for query func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { var al *alias diff --git a/orm/orm_conds.go b/orm/orm_conds.go index 6344653a..b2eae418 100644 --- a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -19,6 +19,7 @@ import ( "strings" ) +// ExprSep define the expression seperation const ( ExprSep = "__" ) @@ -32,19 +33,19 @@ type condValue struct { isCond bool } -// condition struct. +// Condition struct. // work for WHERE conditions. type Condition struct { params []condValue } -// return new condition struct +// NewCondition return new condition struct func NewCondition() *Condition { c := &Condition{} return c } -// add expression to condition +// And add expression to condition func (c Condition) And(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition { return &c } -// add NOT expression to condition +// AndNot add NOT expression to condition func (c Condition) AndNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition { return &c } -// combine a condition to current condition +// AndCond combine a condition to current condition func (c *Condition) AndCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition { return c } -// add OR expression to condition +// Or add OR expression to condition func (c Condition) Or(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition { return &c } -// add OR NOT expression to condition +// OrNot add OR NOT expression to condition func (c Condition) OrNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition { return &c } -// combine a OR condition to current condition +// OrCond combine a OR condition to current condition func (c *Condition) OrCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } -// check the condition arguments are empty or not. +// IsEmpty check the condition arguments are empty or not. func (c *Condition) IsEmpty() bool { return len(c.params) == 0 } -// clone a condition +// clone clone a condition func (c Condition) clone() *Condition { return &c } diff --git a/orm/orm_log.go b/orm/orm_log.go index 419d8e11..712eb219 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -23,11 +23,12 @@ import ( "time" ) +// Log implement the log.Logger type Log struct { *log.Logger } -// set io.Writer to create a Logger. +// NewLog set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) d.Logger = log.New(out, "[ORM]", 1e9) @@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error if err != nil { flag = "FAIL" } - con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query) + con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query) cons := make([]string, 0, len(args)) for _, arg := range args { cons = append(cons, fmt.Sprintf("%v", arg)) diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 1eaccf72..60c77cdf 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -14,9 +14,7 @@ package orm -import ( - "reflect" -) +import "reflect" // model to model struct type queryM2M struct { @@ -44,7 +42,21 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { dbase := orm.alias.DbBaser var models []interface{} + var other_values []interface{} + var other_names []string + for _, colname := range mi.fields.dbcols { + if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && + mi.fields.columns[colname] != mi.fields.pk { + other_names = append(other_names, colname) + } + } + for i, md := range mds { + if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { + other_values = append(other_values, md) + mds = append(mds[:i], mds[i+1:]...) + } + } for _, md := range mds { val := reflect.ValueOf(md) if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { @@ -67,11 +79,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { names := []string{mfi.column, rfi.column} values := make([]interface{}, 0, len(models)*2) - for _, md := range models { ind := reflect.Indirect(reflect.ValueOf(md)) - var v2 interface{} if ind.Kind() != reflect.Struct { v2 = ind.Interface() @@ -81,11 +91,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { panic(ErrMissPK) } } - values = append(values, v1, v2) } - + names = append(names, other_names...) + values = append(values, other_values...) return dbase.InsertValue(orm.db, mi, true, names, values) } diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index 5cc47617..802a1fe0 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -25,11 +25,12 @@ type colValue struct { type operator int +// define Col operations const ( - Col_Add operator = iota - Col_Minus - Col_Multiply - Col_Except + ColAdd operator = iota + ColMinus + ColMultiply + ColExcept ) // ColValue do the field raw changes. e.g Nums = Nums + 10. usage: @@ -38,7 +39,7 @@ const ( // } func ColValue(opt operator, value interface{}) interface{} { switch opt { - case Col_Add, Col_Minus, Col_Multiply, Col_Except: + case ColAdd, ColMinus, ColMultiply, ColExcept: default: panic(fmt.Errorf("orm.ColValue wrong operator")) } @@ -60,7 +61,9 @@ type querySet struct { relDepth int limit int64 offset int64 + groups []string orders []string + distinct bool orm *orm } @@ -105,6 +108,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter { return &o } +// add GROUP expression +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + // add ORDER expression. // "column" means ASC, "-column" means DESC. func (o querySet) OrderBy(exprs ...string) QuerySeter { @@ -112,24 +121,30 @@ func (o querySet) OrderBy(exprs ...string) QuerySeter { return &o } +// add DISTINCT to SELECT +func (o querySet) Distinct() QuerySeter { + o.distinct = true + return &o +} + // set relation model to query together. // it will query relation models and assign to parent model. func (o querySet) RelatedSel(params ...interface{}) QuerySeter { - if len(params) == 0 { - o.relDepth = DefaultRelsDepth - } else { - for _, p := range params { - switch val := p.(type) { - case string: - o.related = append(o.related, val) - case int: - o.relDepth = val - default: - panic(fmt.Errorf(" wrong param kind: %v", val)) - } - } - } - return &o + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + o.related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Errorf(" wrong param kind: %v", val)) + } + } + } + return &o } // set condition to QuerySeter. diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 1452d6fc..cbb18064 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if str != "" { if len(str) >= 19 { str = str[:19] - t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ) + t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) if err == nil { t = t.In(DefaultTimeLoc) ind.Set(reflect.ValueOf(t)) } } else if len(str) >= 10 { str = str[:10] - t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc) + t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) if err == nil { ind.Set(reflect.ValueOf(t)) } @@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr // query data and map to container func (o *rawSet) QueryRow(containers ...interface{}) error { - refs := make([]interface{}, 0, len(containers)) - sInds := make([]reflect.Value, 0) - eTyps := make([]reflect.Type, 0) - + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) structMode := false - var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { // query data rows and map to container func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { - refs := make([]interface{}, 0, len(containers)) - sInds := make([]reflect.Value, 0) - eTyps := make([]reflect.Type, 0) - + var ( + refs = make([]interface{}, 0, len(containers)) + sInds []reflect.Value + eTyps []reflect.Type + sMi *modelInfo + ) structMode := false - var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) sInd := reflect.Indirect(val) @@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er args := getFlatParams(nil, o.args, o.orm.alias.TZ) var rs *sql.Rows - if r, err := o.orm.db.Query(query, args...); err != nil { + rs, err := o.orm.db.Query(query, args...) + if err != nil { return 0, err - } else { - rs = r } defer rs.Close() @@ -574,30 +575,30 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er for rs.Next() { if cnt == 0 { - if columns, err := rs.Columns(); err != nil { + columns, err := rs.Columns() + if err != nil { return 0, err + } + if len(needCols) > 0 { + indexs = make([]int, 0, len(needCols)) } else { + indexs = make([]int, 0, len(columns)) + } + + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + var ref sql.NullString + refs[i] = &ref + if len(needCols) > 0 { - indexs = make([]int, 0, len(needCols)) - } else { - indexs = make([]int, 0, len(columns)) - } - - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - var ref sql.NullString - refs[i] = &ref - - if len(needCols) > 0 { - for _, c := range needCols { - if c == cols[i] { - indexs = append(indexs, i) - } + for _, c := range needCols { + if c == cols[i] { + indexs = append(indexs, i) } - } else { - indexs = append(indexs, i) } + } else { + indexs = append(indexs, i) } } } @@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in args := getFlatParams(nil, o.args, o.orm.alias.TZ) - var rs *sql.Rows - if r, err := o.orm.db.Query(query, args...); err != nil { + rs, err := o.orm.db.Query(query, args...) + if err != nil { return 0, err - } else { - rs = r } defer rs.Close() @@ -706,32 +705,29 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in for rs.Next() { if cnt == 0 { - if columns, err := rs.Columns(); err != nil { + columns, err := rs.Columns() + if err != nil { return 0, err - } else { - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - if keyCol == cols[i] { - keyIndex = i - } - - if typ == 1 || keyIndex == i { - var ref sql.NullString - refs[i] = &ref - } else { - var ref interface{} - refs[i] = &ref - } - - if valueCol == cols[i] { - valueIndex = i - } + } + cols = columns + refs = make([]interface{}, len(cols)) + for i := range refs { + if keyCol == cols[i] { + keyIndex = i } - - if keyIndex == -1 || valueIndex == -1 { - panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) + if typ == 1 || keyIndex == i { + var ref sql.NullString + refs[i] = &ref + } else { + var ref interface{} + refs[i] = &ref } + if valueCol == cols[i] { + valueIndex = i + } + } + if keyIndex == -1 || valueIndex == -1 { + panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) } } diff --git a/orm/orm_test.go b/orm/orm_test.go index 14eadabd..d6f6c7a9 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -31,13 +31,26 @@ import ( var _ = os.PathSeparator var ( - test_Date = format_Date + " -0700" - test_DateTime = format_DateTime + " -0700" + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" ) -func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) { +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { if len(args) == 0 { - return fmt.Errorf("miss args"), false + return false, fmt.Errorf("miss args") } b := args[0] arg := argAny(args) @@ -71,21 +84,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b wrongArg: if err != nil { - return err, false + return false, err } - return nil, true + return true, nil } func AssertIs(a interface{}, args ...interface{}) error { - if err, ok := ValuesCompare(true, a, args...); ok == false { + if ok, err := ValuesCompare(true, a, args...); ok == false { return err } return nil } func AssertNot(a interface{}, args ...interface{}) error { - if err, ok := ValuesCompare(false, a, args...); ok == false { + if ok, err := ValuesCompare(false, a, args...); ok == false { return err } return nil @@ -171,8 +184,11 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Comment)) RegisterModel(new(UserBig)) RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) - err := RunSyncdb("default", true, false) + err := RunSyncdb("default", true, Debug) throwFail(t, err) modelCache.clean() @@ -187,6 +203,9 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Comment)) RegisterModel(new(UserBig)) RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) BootStrap() @@ -208,7 +227,7 @@ func TestModelSyntax(t *testing.T) { } } -var Data_Values = map[string]interface{}{ +var DataValues = map[string]interface{}{ "Boolean": true, "Char": "char", "Text": "text", @@ -235,7 +254,7 @@ func TestDataTypes(t *testing.T) { d := Data{} ind := reflect.Indirect(reflect.ValueOf(&d)) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) e.Set(reflect.ValueOf(value)) } @@ -244,22 +263,22 @@ func TestDataTypes(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) - d = Data{Id: 1} + d = Data{ID: 1} err = dORM.Read(&d) throwFail(t, err) ind = reflect.Indirect(reflect.ValueOf(&d)) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) vu := e.Interface() switch name { case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) } throwFail(t, AssertIs(vu == value, true), value, vu) } @@ -278,7 +297,7 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) - d = DataNull{Id: 1} + d = DataNull{ID: 1} err = dORM.Read(&d) throwFail(t, err) @@ -309,7 +328,7 @@ func TestNullDataTypes(t *testing.T) { _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() throwFail(t, err) - d = DataNull{Id: 2} + d = DataNull{ID: 2} err = dORM.Read(&d) throwFail(t, err) @@ -362,7 +381,7 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 3)) - d = DataNull{Id: 3} + d = DataNull{ID: 3} err = dORM.Read(&d) throwFail(t, err) @@ -402,7 +421,7 @@ func TestDataCustomTypes(t *testing.T) { d := DataCustom{} ind := reflect.Indirect(reflect.ValueOf(&d)) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) if !e.IsValid() { continue @@ -414,13 +433,13 @@ func TestDataCustomTypes(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) - d = DataCustom{Id: 1} + d = DataCustom{ID: 1} err = dORM.Read(&d) throwFail(t, err) ind = reflect.Indirect(reflect.ValueOf(&d)) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) if !e.IsValid() { continue @@ -451,7 +470,7 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) - u := &User{Id: user.Id} + u := &User{ID: user.ID} err = dORM.Read(u) throwFail(t, err) @@ -461,8 +480,8 @@ func TestCRUD(t *testing.T) { throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) user.UserName = "astaxie" user.Profile = profile @@ -470,11 +489,11 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - u = &User{Id: user.Id} + u = &User{ID: user.ID} err = dORM.Read(u) throwFailNow(t, err) throwFail(t, AssertIs(u.UserName, "astaxie")) - throwFail(t, AssertIs(u.Profile.Id, profile.Id)) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) u = &User{UserName: "astaxie", Password: "pass"} err = dORM.Read(u, "UserName") @@ -487,7 +506,7 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - u = &User{Id: user.Id} + u = &User{ID: user.ID} err = dORM.Read(u) throwFailNow(t, err) throwFail(t, AssertIs(u.UserName, "QQ")) @@ -497,7 +516,7 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - u = &User{Id: user.Id} + u = &User{ID: user.ID} err = dORM.Read(u) throwFail(t, err) throwFail(t, AssertIs(true, u.Profile == nil)) @@ -506,7 +525,7 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - u = &User{Id: 100} + u = &User{ID: 100} err = dORM.Read(u) throwFail(t, AssertIs(err, ErrNoRows)) @@ -516,7 +535,7 @@ func TestCRUD(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(id, 1)) - ub = UserBig{Id: 1} + ub = UserBig{ID: 1} err = dORM.Read(&ub) throwFail(t, err) throwFail(t, AssertIs(ub.Name, "name")) @@ -586,7 +605,7 @@ func TestInsertTestData(t *testing.T) { throwFail(t, AssertIs(id, 4)) tags := []*Tag{ - {Name: "golang", BestPost: &Post{Id: 2}}, + {Name: "golang", BestPost: &Post{ID: 2}}, {Name: "example"}, {Name: "format"}, {Name: "c++"}, @@ -635,10 +654,47 @@ The program—and web server—godoc processes Go source files to extract docume throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + } func TestCustomField(t *testing.T) { - user := User{Id: 2} + user := User{ID: 2} err := dORM.Read(&user) throwFailNow(t, err) @@ -648,7 +704,7 @@ func TestCustomField(t *testing.T) { _, err = dORM.Update(&user, "Langs", "Extra") throwFailNow(t, err) - user = User{Id: 2} + user = User{ID: 2} err = dORM.Read(&user) throwFailNow(t, err) throwFailNow(t, AssertIs(len(user.Langs), 2)) @@ -702,7 +758,7 @@ func TestOperators(t *testing.T) { var shouldNum int - if IsSqlite { + if IsSqlite || IsTidb { shouldNum = 2 } else { shouldNum = 0 @@ -740,7 +796,7 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - if IsSqlite { + if IsSqlite || IsTidb { shouldNum = 1 } else { shouldNum = 0 @@ -758,7 +814,7 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 2)) - if IsSqlite { + if IsSqlite || IsTidb { shouldNum = 2 } else { shouldNum = 0 @@ -889,9 +945,9 @@ func TestAll(t *testing.T) { throwFailNow(t, AssertIs(users2[0].UserName, "slene")) throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - throwFailNow(t, AssertIs(users2[0].Id, 0)) - throwFailNow(t, AssertIs(users2[1].Id, 0)) - throwFailNow(t, AssertIs(users2[2].Id, 0)) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) @@ -986,6 +1042,10 @@ func TestValuesFlat(t *testing.T) { } func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } qs := dORM.QueryTable("user") num, err := qs.Filter("profile__age", 28).Count() throwFail(t, err) @@ -1112,7 +1172,7 @@ func TestReverseQuery(t *testing.T) { func TestLoadRelated(t *testing.T) { // load reverse foreign key - user := User{Id: 3} + user := User{ID: 3} err := dORM.Read(&user) throwFailNow(t, err) @@ -1121,7 +1181,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) num, err = dORM.LoadRelated(&user, "Posts", true) throwFailNow(t, err) @@ -1143,8 +1203,8 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) // load reverse one to one - profile := Profile{Id: 3} - profile.BestPost = &Post{Id: 2} + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} num, err = dORM.Update(&profile, "BestPost") throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) @@ -1183,7 +1243,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) - post := Post{Id: 2} + post := Post{ID: 2} // load rel foreign key err = dORM.Read(&post) @@ -1204,7 +1264,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) // load rel m2m - post = Post{Id: 2} + post = Post{ID: 2} err = dORM.Read(&post) throwFailNow(t, err) @@ -1224,7 +1284,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) // load reverse m2m - tag := Tag{Id: 1} + tag := Tag{ID: 1} err = dORM.Read(&tag) throwFailNow(t, err) @@ -1233,19 +1293,19 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) num, err = dORM.LoadRelated(&tag, "Posts", true) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) } func TestQueryM2M(t *testing.T) { - post := Post{Id: 4} + post := Post{ID: 4} m2m := dORM.QueryM2M(&post, "Tags") tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} @@ -1319,7 +1379,7 @@ func TestQueryM2M(t *testing.T) { for _, post := range posts { p := post.(*Post) - p.User = &User{Id: 1} + p.User = &User{ID: 1} _, err := dORM.Insert(post) throwFailNow(t, err) } @@ -1394,6 +1454,18 @@ func TestQueryRelate(t *testing.T) { // throwFailNow(t, AssertIs(num, 2)) } +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + func TestPrepareInsert(t *testing.T) { qs := dORM.QueryTable("user") i, err := qs.PrepareInsert() @@ -1459,10 +1531,10 @@ func TestRawQueryRow(t *testing.T) { Decimal float64 ) - data_values := make(map[string]interface{}, len(Data_Values)) + dataValues := make(map[string]interface{}, len(DataValues)) - for k, v := range Data_Values { - data_values[strings.ToLower(k)] = v + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v } Q := dDbBaser.TableQuote() @@ -1488,14 +1560,14 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(id, 1)) case "date": v = v.(time.Time).In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_Date)) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) case "datetime": v = v.(time.Time).In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_DateTime)) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) default: - throwFail(t, AssertIs(v, data_values[col])) + throwFail(t, AssertIs(v, dataValues[col])) } } @@ -1529,16 +1601,16 @@ func TestQueryRows(t *testing.T) { ind := reflect.Indirect(reflect.ValueOf(datas[0])) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) vu := e.Interface() switch name { case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) } throwFail(t, AssertIs(vu == value, true), value, vu) } @@ -1553,16 +1625,16 @@ func TestQueryRows(t *testing.T) { ind = reflect.Indirect(reflect.ValueOf(datas2[0])) - for name, value := range Data_Values { + for name, value := range DataValues { e := ind.FieldByName(name) vu := e.Interface() switch name { case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) } throwFail(t, AssertIs(vu == value, true), value, vu) } @@ -1699,25 +1771,25 @@ func TestUpdate(t *testing.T) { throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(Col_Add, 100), + "Nums": ColValue(ColAdd, 100), }) throwFail(t, err) throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(Col_Minus, 50), + "Nums": ColValue(ColMinus, 50), }) throwFail(t, err) throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(Col_Multiply, 3), + "Nums": ColValue(ColMultiply, 3), }) throwFail(t, err) throwFail(t, AssertIs(num, 1)) num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(Col_Except, 5), + "Nums": ColValue(ColExcept, 5), }) throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -1838,15 +1910,15 @@ func TestReadOrCreate(t *testing.T) { throwFail(t, AssertIs(u.Status, 7)) throwFail(t, AssertIs(u.IsStaff, false)) throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} created, pk, err = dORM.ReadOrCreate(nu, "UserName") throwFail(t, err) throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.Id, u.Id)) - throwFail(t, AssertIs(pk, u.Id)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) throwFail(t, AssertIs(nu.UserName, u.UserName)) throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above throwFail(t, AssertIs(nu.Password, u.Password)) diff --git a/orm/qb.go b/orm/qb.go index efe368db..9f778916 100644 --- a/orm/qb.go +++ b/orm/qb.go @@ -16,6 +16,7 @@ package orm import "errors" +// QueryBuilder is the Query builder interface type QueryBuilder interface { Select(fields ...string) QueryBuilder From(tables ...string) QueryBuilder @@ -43,15 +44,18 @@ type QueryBuilder interface { String() string } +// NewQueryBuilder return the QueryBuilder func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { if driver == "mysql" { qb = new(MySQLQueryBuilder) + } else if driver == "tidb" { + qb = new(TiDBQueryBuilder) } else if driver == "postgres" { - err = errors.New("postgres query builder is not supported yet!") + err = errors.New("postgres query builder is not supported yet") } else if driver == "sqlite" { - err = errors.New("sqlite query builder is not supported yet!") + err = errors.New("sqlite query builder is not supported yet") } else { - err = errors.New("unknown driver for query builder!") + err = errors.New("unknown driver for query builder") } return } diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go index 9ce9b7d9..f6d1e185 100644 --- a/orm/qb_mysql.go +++ b/orm/qb_mysql.go @@ -20,134 +20,160 @@ import ( "strings" ) -const COMMA_SPACE = ", " +// CommaSpace is the seperation +const CommaSpace = ", " +// MySQLQueryBuilder is the SQL build type MySQLQueryBuilder struct { Tokens []string } +// Select will join the fields func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) return qb } +// From join the tables func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) return qb } +// InnerJoin INNER JOIN the table func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { qb.Tokens = append(qb.Tokens, "INNER JOIN", table) return qb } +// LeftJoin LEFT JOIN the table func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) return qb } +// RightJoin RIGHT JOIN the table func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) return qb } +// On join with on cond func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { qb.Tokens = append(qb.Tokens, "ON", cond) return qb } +// Where join the Where cond func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { qb.Tokens = append(qb.Tokens, "WHERE", cond) return qb } +// And join the and cond func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { qb.Tokens = append(qb.Tokens, "AND", cond) return qb } +// Or join the or cond func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { qb.Tokens = append(qb.Tokens, "OR", cond) return qb } +// In join the IN (vals) func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")") + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") return qb } +// OrderBy join the Order by fields func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) return qb } +// Asc join the asc func (qb *MySQLQueryBuilder) Asc() QueryBuilder { qb.Tokens = append(qb.Tokens, "ASC") return qb } +// Desc join the desc func (qb *MySQLQueryBuilder) Desc() QueryBuilder { qb.Tokens = append(qb.Tokens, "DESC") return qb } +// Limit join the limit num func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) return qb } +// Offset join the offset num func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) return qb } +// GroupBy join the Group by fields func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) return qb } +// Having join the Having cond func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { qb.Tokens = append(qb.Tokens, "HAVING", cond) return qb } +// Update join the update table func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) return qb } +// Set join the set kv func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) return qb } +// Delete join the Delete tables func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { qb.Tokens = append(qb.Tokens, "DELETE") if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE)) + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) } return qb } +// InsertInto join the insert SQL func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { qb.Tokens = append(qb.Tokens, "INSERT INTO", table) if len(fields) != 0 { - fieldsStr := strings.Join(fields, COMMA_SPACE) + fieldsStr := strings.Join(fields, CommaSpace) qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") } return qb } +// Values join the Values(vals) func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, COMMA_SPACE) + valsStr := strings.Join(vals, CommaSpace) qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") return qb } +// Subquery join the sub as alias func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { return fmt.Sprintf("(%s) AS %s", sub, alias) } +// String join all Tokens func (qb *MySQLQueryBuilder) String() string { return strings.Join(qb.Tokens, " ") } diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go new file mode 100644 index 00000000..c504049e --- /dev/null +++ b/orm/qb_tidb.go @@ -0,0 +1,176 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder struct { + Tokens []string +} + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + return qb +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + return qb +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ON", cond) + return qb +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + qb.Tokens = append(qb.Tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + return qb +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + return qb +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "DELETE") + if len(tables) != 0 { + qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return strings.Join(qb.Tokens, " ") +} diff --git a/orm/types.go b/orm/types.go index b46be4fc..5fac5fed 100644 --- a/orm/types.go +++ b/orm/types.go @@ -20,13 +20,13 @@ import ( "time" ) -// database driver +// Driver define database driver type Driver interface { Name() string Type() DriverType } -// field info +// Fielder define field info type Fielder interface { String() string FieldType() int @@ -34,84 +34,315 @@ type Fielder interface { RawValue() interface{} } -// orm struct +// Ormer define the orm interface type Ormer interface { - Read(interface{}, ...string) error - ReadOrCreate(interface{}, string, ...string) (bool, int64, error) + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must a pointer and Insert will set user's pk field Insert(interface{}) (int64, error) - InsertMulti(int, interface{}) (int64, error) - Update(interface{}, ...string) (int64, error) - Delete(interface{}) (int64, error) - LoadRelated(interface{}, string, ...interface{}) (int64, error) - QueryM2M(interface{}, string) QueryM2Mer - QueryTable(interface{}) QuerySeter - Using(string) error + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + //args[0] bool true useDefaultRelsDepth ; false depth 0 + //args[0] int loadRelationDepth + //args[1] int limit default limit 1000 + //args[2] int offset default offset 0 + //args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() Begin() error + // commit transaction Commit() error + // rollback transaction Rollback() error - Raw(string, ...interface{}) RawSeter + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter Driver() Driver } -// insert prepared statement +// Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) Close() error } -// query seter +// QuerySeter query seter type QuerySeter interface { + // add condition expression to QuerySeter. + // for example: + // filter by UserName == 'slene' + // qs.Filter("UserName", "slene") + // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 + // Filter("profile__Age", 28) + // // time compare + // qs.Filter("created", time.Now()) Filter(string, ...interface{}) QuerySeter + // add NOT condition to querySeter. + // have the same usage as Filter Exclude(string, ...interface{}) QuerySeter + // set condition to QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond1).Count() SetCond(*Condition) QuerySeter - Limit(interface{}, ...interface{}) QuerySeter - Offset(interface{}) QuerySeter - OrderBy(...string) QuerySeter - RelatedSel(...interface{}) QuerySeter + // add LIMIT value. + // args[0] means offset, e.g. LIMIT num,offset. + // if Limit <= 0 then Limit will be set to default limit ,eg 1000 + // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 + // for example: + // qs.Limit(10, 2) + // // sql-> limit 10 offset 2 + Limit(limit interface{}, args ...interface{}) QuerySeter + // add OFFSET value + // same as Limit function's args[0] + Offset(offset interface{}) QuerySeter + // add ORDER expression. + // "column" means ASC, "-column" means DESC. + // for example: + // qs.OrderBy("-status") + OrderBy(exprs ...string) QuerySeter + // set relation model to query together. + // it will query relation models and assign to parent model. + // for example: + // // will load all related fields use left join . + // qs.RelatedSel().One(&user) + // // will load related field only profile + // qs.RelatedSel("profile").One(&user) + // user.Profile.Age = 32 + RelatedSel(params ...interface{}) QuerySeter + // return QuerySeter execution result number + // for example: + // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + // check result empty or not after QuerySeter executed + // the same as QuerySeter.Count > 0 Exist() bool - Update(Params) (int64, error) + // execute update with parameters + // for example: + // num, err = qs.Filter("user_name", "slene").Update(Params{ + // "Nums": ColValue(Col_Minus, 50), + // }) // user slene's Nums will minus 50 + // num, err = qs.Filter("UserName", "slene").Update(Params{ + // "user_name": "slene2" + // }) // user slene's name will change to slene2 + Update(values Params) (int64, error) + // delete from table + //for example: + // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + // return a insert queryer. + // it can be used in times. + // example: + // i,err := sq.PrepareInsert() + // num, err = i.Insert(&user1) // user table will add one record user1 at once + // num, err = i.Insert(&user2) // user table will add one record user2 at once + // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) - All(interface{}, ...string) (int64, error) - One(interface{}, ...string) error - Values(*[]Params, ...string) (int64, error) - ValuesList(*[]ParamsList, ...string) (int64, error) - ValuesFlat(*ParamsList, string) (int64, error) - RowsToMap(*Params, string, string) (int64, error) - RowsToStruct(interface{}, string, string) (int64, error) + // query all data and map to containers. + // cols means the columns when querying. + // for example: + // var users []*User + // qs.All(&users) // users[0],users[1],users[2] ... + All(container interface{}, cols ...string) (int64, error) + // query one row data and map to containers. + // cols means the columns when querying. + // for example: + // var user User + // qs.One(&user) //user.UserName == "slene" + One(container interface{}, cols ...string) error + // query all data and map to []map[string]interface. + // expres means condition expression. + // it converts data to []map[column]value. + // for example: + // var maps []Params + // qs.Values(&maps) //maps[0]["UserName"]=="slene" + Values(results *[]Params, exprs ...string) (int64, error) + // query all data and map to [][]interface + // it converts data to [][column_index]value + // for example: + // var list []ParamsList + // qs.ValuesList(&list) // list[0][1] == "slene" + ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + // query all data and map to []interface. + // it's designed for one column record set, auto change to []value, not [][column]value. + // for example: + // var list ParamsList + // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" + ValuesFlat(result *ParamsList, expr string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) } -// model to model query struct +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table type QueryM2Mer interface { + // add models to origin models when creating queryM2M. + // example: + // m2m := orm.QueryM2M(post,"Tag") + // m2m.Add(&Tag1{},&Tag2{}) + // for _,tag := range post.Tags{}{ ... } + // param could also be any of the follow + // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} + // &Tag{Id:5,Name: "TestTag3"} + // []interface{}{&Tag{Id:6,Name: "TestTag4"}} + // insert one or more rows to m2m table + // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + // remove models following the origin model relationship + // only delete rows from m2m table + // for example: + //tag3 := &Tag{Id:5,Name: "TestTag3"} + //num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + // check model is existed in relationship of origin model Exist(interface{}) bool + // clean all models in related of origin model Clear() (int64, error) + // count all related models of origin model Count() (int64, error) } -// raw query statement +// RawPreparer raw query statement type RawPreparer interface { Exec(...interface{}) (sql.Result, error) Close() error } -// raw query seter +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) type RawSeter interface { + //execute sql and get result Exec() (sql.Result, error) - QueryRow(...interface{}) error - QueryRows(...interface{}) (int64, error) + //query data and map to container + //for example: + // var name string + // var id int + // rs.QueryRow(&id,&name) // id==2 name=="slene" + QueryRow(containers ...interface{}) error + + // query data rows and map to container + // var ids []int + // var names []int + // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) + // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} + QueryRows(containers ...interface{}) (int64, error) SetArgs(...interface{}) RawSeter - Values(*[]Params, ...string) (int64, error) - ValuesList(*[]ParamsList, ...string) (int64, error) - ValuesFlat(*ParamsList, ...string) (int64, error) - RowsToMap(*Params, string, string) (int64, error) - RowsToStruct(interface{}, string, string) (int64, error) + // query data to []map[string]interface + // see QuerySeter's Values + Values(container *[]Params, cols ...string) (int64, error) + // query data to [][]interface + // see QuerySeter's ValuesList + ValuesList(container *[]ParamsList, cols ...string) (int64, error) + // query data to []interface + // see QuerySeter's ValuesFlat + ValuesFlat(container *ParamsList, cols ...string) (int64, error) + // query all rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // query all rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + + // return prepared raw statement for used in times. + // for example: + // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) Prepare() (RawPreparer, error) } -// statement querier +// stmtQuerier statement querier type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) @@ -160,8 +391,8 @@ type dbBaser interface { UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - OperatorSql(string) string - GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + OperatorSQL(string) string + GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) diff --git a/orm/utils.go b/orm/utils.go index 88168763..99437c7b 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -22,9 +22,10 @@ import ( "time" ) +// StrTo is the target string type StrTo string -// set string +// Set string func (f *StrTo) Set(v string) { if v != "" { *f = StrTo(v) @@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) { } } -// clean string +// Clear string func (f *StrTo) Clear() { *f = StrTo(0x1E) } -// check string exist +// Exist check string exist func (f StrTo) Exist() bool { return string(f) != string(0x1E) } -// string to bool +// Bool string to bool func (f StrTo) Bool() (bool, error) { return strconv.ParseBool(f.String()) } -// string to float32 +// Float32 string to float32 func (f StrTo) Float32() (float32, error) { v, err := strconv.ParseFloat(f.String(), 32) return float32(v), err } -// string to float64 +// Float64 string to float64 func (f StrTo) Float64() (float64, error) { return strconv.ParseFloat(f.String(), 64) } -// string to int +// Int string to int func (f StrTo) Int() (int, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int(v), err } -// string to int8 +// Int8 string to int8 func (f StrTo) Int8() (int8, error) { v, err := strconv.ParseInt(f.String(), 10, 8) return int8(v), err } -// string to int16 +// Int16 string to int16 func (f StrTo) Int16() (int16, error) { v, err := strconv.ParseInt(f.String(), 10, 16) return int16(v), err } -// string to int32 +// Int32 string to int32 func (f StrTo) Int32() (int32, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int32(v), err } -// string to int64 +// Int64 string to int64 func (f StrTo) Int64() (int64, error) { v, err := strconv.ParseInt(f.String(), 10, 64) return int64(v), err } -// string to uint +// Uint string to uint func (f StrTo) Uint() (uint, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint(v), err } -// string to uint8 +// Uint8 string to uint8 func (f StrTo) Uint8() (uint8, error) { v, err := strconv.ParseUint(f.String(), 10, 8) return uint8(v), err } -// string to uint16 +// Uint16 string to uint16 func (f StrTo) Uint16() (uint16, error) { v, err := strconv.ParseUint(f.String(), 10, 16) return uint16(v), err } -// string to uint31 +// Uint32 string to uint31 func (f StrTo) Uint32() (uint32, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint32(v), err } -// string to uint64 +// Uint64 string to uint64 func (f StrTo) Uint64() (uint64, error) { v, err := strconv.ParseUint(f.String(), 10, 64) return uint64(v), err } -// string to string +// String string to string func (f StrTo) String() string { if f.Exist() { return string(f) @@ -127,7 +128,7 @@ func (f StrTo) String() string { return "" } -// interface to string +// ToStr interface to string func ToStr(value interface{}, args ...int) (s string) { switch v := value.(type) { case bool: @@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) { return s } -// interface to int64 +// ToInt64 interface to int64 func ToInt64(value interface{}) (d int64) { val := reflect.ValueOf(value) switch value.(type) { @@ -248,30 +249,12 @@ func (a argInt) Get(i int, args ...int) (r int) { return } -type argAny []interface{} - -// get interface by index from interface slice -func (a argAny) Get(i int, args ...interface{}) (r interface{}) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - // parse time to string with location func timeParse(dateString, format string) (time.Time, error) { tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) return tp, err } -// format time string -func timeFormat(t time.Time, format string) string { - return t.Format(format) -} - // get pointer indirect type func indirectType(v reflect.Type) reflect.Type { switch v.Kind() { diff --git a/parser.go b/parser.go index cf8dee22..b14d74b9 100644 --- a/parser.go +++ b/parser.go @@ -42,13 +42,13 @@ func init() { ` var ( - lastupdateFilename string = "lastupdate.tmp" + lastupdateFilename = "lastupdate.tmp" commentFilename string pkgLastupdate map[string]int64 genInfoList map[string][]ControllerComments ) -const COMMENTFL = "commentsRouter_" +const coomentPrefix = "commentsRouter_" func init() { pkgLastupdate = make(map[string]int64) @@ -56,7 +56,7 @@ func init() { func parserPkg(pkgRealpath, pkgpath string) error { rep := strings.NewReplacer("/", "_", ".", "_") - commentFilename = COMMENTFL + rep.Replace(pkgpath) + ".go" + commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go" if !compareFile(pkgRealpath) { Info(pkgRealpath + " has not changed, not reloading") return nil @@ -77,7 +77,10 @@ func parserPkg(pkgRealpath, pkgpath string) error { switch specDecl := d.(type) { case *ast.FuncDecl: if specDecl.Recv != nil { - parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(specDecl.Recv.List[0].Type.(*ast.StarExpr).X), pkgpath) + exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser + if ok { + parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) + } } } } @@ -127,11 +130,13 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat } func genRouterCode() { - os.Mkdir(path.Join(workPath, "routers"), 0755) + os.Mkdir("routers", 0755) Info("generate router from comments") - var globalinfo string - sortKey := make([]string, 0) - for k, _ := range genInfoList { + var ( + globalinfo string + sortKey []string + ) + for k := range genInfoList { sortKey = append(sortKey, k) } sort.Strings(sortKey) @@ -167,7 +172,7 @@ func genRouterCode() { } } if globalinfo != "" { - f, err := os.Create(path.Join(workPath, "routers", commentFilename)) + f, err := os.Create(path.Join("routers", commentFilename)) if err != nil { panic(err) } @@ -177,11 +182,11 @@ func genRouterCode() { } func compareFile(pkgRealpath string) bool { - if !utils.FileExists(path.Join(workPath, "routers", commentFilename)) { + if !utils.FileExists(path.Join("routers", commentFilename)) { return true } - if utils.FileExists(path.Join(workPath, lastupdateFilename)) { - content, err := ioutil.ReadFile(path.Join(workPath, lastupdateFilename)) + if utils.FileExists(lastupdateFilename) { + content, err := ioutil.ReadFile(lastupdateFilename) if err != nil { return true } @@ -209,7 +214,7 @@ func savetoFile(pkgRealpath string) { if err != nil { return } - ioutil.WriteFile(path.Join(workPath, lastupdateFilename), d, os.ModePerm) + ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) } func getpathTime(pkgRealpath string) (lastupdate int64, err error) { diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go index bbae7def..3091c698 100644 --- a/plugins/apiauth/apiauth.go +++ b/plugins/apiauth/apiauth.go @@ -33,7 +33,7 @@ // // maybe store in configure, maybe in database // } // -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIAuthWithFunc(getAppSecret, 360)) +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) // // Infomation: // @@ -68,8 +68,10 @@ import ( "github.com/astaxie/beego/context" ) -type AppIdToAppSecret func(string) string +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret func(string) string +// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret func APIBaiscAuth(appid, appkey string) beego.FilterFunc { ft := func(aid string) string { if aid == appid { @@ -77,52 +79,54 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc { } return "" } - return APIAuthWithFunc(ft, 300) + return APISecretAuth(ft, 300) } -func APIAuthWithFunc(f AppIdToAppSecret, timeout int) beego.FilterFunc { +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { return func(ctx *context.Context) { if ctx.Input.Query("appid") == "" { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("miss query param: appid") return } appsecret := f(ctx.Input.Query("appid")) if appsecret == "" { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("not exist this appid") return } if ctx.Input.Query("signature") == "" { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("miss query param: signature") return } if ctx.Input.Query("timestamp") == "" { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("miss query param: timestamp") return } u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) if err != nil { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") return } t := time.Now() if t.Sub(u).Seconds() > float64(timeout) { - ctx.Output.SetStatus(403) + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("timeout! the request time is long ago, please try again") return } if ctx.Input.Query("signature") != - Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.Uri()) { - ctx.Output.SetStatus(403) + Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URI()) { + ctx.ResponseWriter.WriteHeader(403) ctx.WriteString("auth failed") } } } +// Signature used to generate signature with the appsecret/method/params/RequestURI func Signature(appsecret, method string, params url.Values, RequestURI string) (result string) { var query string pa := make(map[string]string) @@ -139,11 +143,11 @@ func Signature(appsecret, method string, params url.Values, RequestURI string) ( query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i]) } } - string_to_sign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI) + stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI) sha256 := sha256.New hash := hmac.New(sha256, []byte(appsecret)) - hash.Write([]byte(string_to_sign)) + hash.Write([]byte(stringToSign)) return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go index 946b8457..c478044a 100644 --- a/plugins/auth/basic.go +++ b/plugins/auth/basic.go @@ -46,6 +46,7 @@ import ( var defaultRealm = "Authorization Required" +// Basic is the http basic auth func Basic(username string, password string) beego.FilterFunc { secrets := func(user, pass string) bool { return user == username && pass == password @@ -53,6 +54,7 @@ func Basic(username string, password string) beego.FilterFunc { return NewBasicAuthenticator(secrets, defaultRealm) } +// NewBasicAuthenticator return the BasicAuth func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { return func(ctx *context.Context) { a := &BasicAuth{Secrets: secrets, Realm: Realm} @@ -62,17 +64,19 @@ func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFun } } +// SecretProvider is the SecretProvider function type SecretProvider func(user, pass string) bool +// BasicAuth store the SecretProvider and Realm type BasicAuth struct { Secrets SecretProvider Realm string } -//Checks the username/password combination from the request. Returns -//either an empty string (authentication failed) or the name of the -//authenticated user. -//Supports MD5 and SHA1 password entries +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries func (a *BasicAuth) CheckAuth(r *http.Request) string { s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) if len(s) != 2 || s[0] != "Basic" { @@ -94,8 +98,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string { return "" } -//http.Handler for BasicAuth which initiates the authentication process -//(or requires reauthentication). +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) w.WriteHeader(401) diff --git a/plugins/cors/cors.go b/plugins/cors/cors.go index 052d3bc6..1e973a40 100644 --- a/plugins/cors/cors.go +++ b/plugins/cors/cors.go @@ -24,7 +24,7 @@ // // - PUT and PATCH methods // // - Origin header // // - Credentials share -// beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{ +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ // AllowOrigins: []string{"https://*.foo.com"}, // AllowMethods: []string{"PUT", "PATCH"}, // AllowHeaders: []string{"Origin"}, @@ -36,7 +36,6 @@ package cors import ( - "net/http" "regexp" "strconv" "strings" @@ -216,8 +215,6 @@ func Allow(opts *Options) beego.FilterFunc { for key, value := range headers { ctx.Output.Header(key, value) } - ctx.Output.SetStatus(http.StatusOK) - ctx.WriteString("") return } headers = opts.Header(origin) diff --git a/plugins/cors/cors_test.go b/plugins/cors/cors_test.go index 5c02ab98..34039143 100644 --- a/plugins/cors/cors_test.go +++ b/plugins/cors/cors_test.go @@ -25,21 +25,23 @@ import ( "github.com/astaxie/beego/context" ) -type HttpHeaderGuardRecorder struct { +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { *httptest.ResponseRecorder savedHeaderMap http.Header } -func NewRecorder() *HttpHeaderGuardRecorder { - return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil} +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} } -func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) { +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { gr.ResponseRecorder.WriteHeader(code) gr.savedHeaderMap = gr.ResponseRecorder.Header() } -func (gr *HttpHeaderGuardRecorder) Header() http.Header { +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { if gr.savedHeaderMap != nil { // headers were written. clone so we don't get updates clone := make(http.Header) @@ -47,9 +49,8 @@ func (gr *HttpHeaderGuardRecorder) Header() http.Header { clone[k] = v } return clone - } else { - return gr.ResponseRecorder.Header() } + return gr.ResponseRecorder.Header() } func Test_AllowAll(t *testing.T) { @@ -219,13 +220,13 @@ func Test_Preflight(t *testing.T) { func Benchmark_WithoutCORS(b *testing.B) { recorder := httptest.NewRecorder() handler := beego.NewControllerRegister() - beego.RunMode = "prod" + beego.BConfig.RunMode = beego.PROD handler.Any("/foo", func(ctx *context.Context) { ctx.Output.SetStatus(500) }) b.ResetTimer() - for i := 0; i < 100; i++ { - r, _ := http.NewRequest("PUT", "/foo", nil) + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { handler.ServeHTTP(recorder, r) } } @@ -233,7 +234,7 @@ func Benchmark_WithoutCORS(b *testing.B) { func Benchmark_WithCORS(b *testing.B) { recorder := httptest.NewRecorder() handler := beego.NewControllerRegister() - beego.RunMode = "prod" + beego.BConfig.RunMode = beego.PROD handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowCredentials: true, @@ -245,8 +246,8 @@ func Benchmark_WithCORS(b *testing.B) { ctx.Output.SetStatus(500) }) b.ResetTimer() - for i := 0; i < 100; i++ { - r, _ := http.NewRequest("PUT", "/foo", nil) + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { handler.ServeHTTP(recorder, r) } } diff --git a/router.go b/router.go index 3e1ebab3..9d8cd2e6 100644 --- a/router.go +++ b/router.go @@ -15,10 +15,7 @@ package beego import ( - "bufio" - "errors" "fmt" - "net" "net/http" "os" "path" @@ -27,6 +24,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" beecontext "github.com/astaxie/beego/context" @@ -34,8 +32,8 @@ import ( "github.com/astaxie/beego/utils" ) +// default filter execution points const ( - // default filter execution points BeforeStatic = iota BeforeRouter BeforeExec @@ -50,7 +48,7 @@ const ( ) var ( - // supported http methods. + // HTTPMETHOD list the supported http methods. HTTPMETHOD = map[string]string{ "GET": "GET", "POST": "POST", @@ -71,10 +69,12 @@ var ( "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", "GetControllerAndAction"} - url_placeholder = "{{placeholder}}" - DefaultLogFilter FilterHandler = &logFilter{} + urlPlaceholder = "{{placeholder}}" + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &logFilter{} ) +// FilterHandler is an interface for type FilterHandler interface { Filter(*beecontext.Context) bool } @@ -84,11 +84,11 @@ type logFilter struct { } func (l *logFilter) Filter(ctx *beecontext.Context) bool { - requestPath := path.Clean(ctx.Input.Request.URL.Path) + requestPath := path.Clean(ctx.Request.URL.Path) if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { return true } - for prefix := range StaticDir { + for prefix := range BConfig.WebConfig.StaticDir { if strings.HasPrefix(requestPath, prefix) { return true } @@ -96,7 +96,7 @@ func (l *logFilter) Filter(ctx *beecontext.Context) bool { return false } -// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter func ExceptMethodAppend(action string) { exceptMethod = append(exceptMethod, action) } @@ -106,26 +106,31 @@ type controllerInfo struct { controllerType reflect.Type methods map[string]string handler http.Handler - runfunction FilterFunc + runFunction FilterFunc routerType int } -// ControllerRegistor containers registered router rules, controller handlers and filters. -type ControllerRegistor struct { +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister struct { routers map[string]*Tree enableFilter bool filters map[int][]*FilterRouter + pool sync.Pool } -// NewControllerRegister returns a new ControllerRegistor. -func NewControllerRegister() *ControllerRegistor { - return &ControllerRegistor{ +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + cr := &ControllerRegister{ routers: make(map[string]*Tree), filters: make(map[int][]*FilterRouter), } + cr.pool.New = func() interface{} { + return beecontext.NewContext() + } + return cr } -// Add controller handler and pattern rules to ControllerRegistor. +// Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method // Add("/user",&UserController{}) @@ -133,9 +138,9 @@ func NewControllerRegister() *ControllerRegistor { // Add("/api/create",&RestController{},"post:CreateFood") // Add("/api/update",&RestController{},"put:UpdateFood") // Add("/api/delete",&RestController{},"delete:DeleteFood") -// Add("/api",&RestController{},"get,post:ApiFunc") +// Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) { +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) @@ -183,8 +188,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM } } -func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) { - if !RouterCaseSensitive { +func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) { + if !BConfig.RouterCaseSensitive { pattern = strings.ToLower(pattern) } if t, ok := p.routers[method]; ok { @@ -196,10 +201,10 @@ func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerIn } } -// only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller // Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -func (p *ControllerRegistor) Include(cList ...ControllerInterface) { - if RunMode == "dev" { +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + if BConfig.RunMode == DEV { skip := make(map[string]bool, 10) for _, c := range cList { reflectVal := reflect.ValueOf(c) @@ -238,91 +243,91 @@ func (p *ControllerRegistor) Include(cList ...ControllerInterface) { } } -// add get method +// Get add get method // usage: // Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Get(pattern string, f FilterFunc) { +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { p.AddMethod("get", pattern, f) } -// add post method +// Post add post method // usage: // Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Post(pattern string, f FilterFunc) { +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { p.AddMethod("post", pattern, f) } -// add put method +// Put add put method // usage: // Put("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Put(pattern string, f FilterFunc) { +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { p.AddMethod("put", pattern, f) } -// add delete method +// Delete add delete method // usage: // Delete("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Delete(pattern string, f FilterFunc) { +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { p.AddMethod("delete", pattern, f) } -// add head method +// Head add head method // usage: // Head("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Head(pattern string, f FilterFunc) { +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { p.AddMethod("head", pattern, f) } -// add patch method +// Patch add patch method // usage: // Patch("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Patch(pattern string, f FilterFunc) { +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { p.AddMethod("patch", pattern, f) } -// add options method +// Options add options method // usage: // Options("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Options(pattern string, f FilterFunc) { +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { p.AddMethod("options", pattern, f) } -// add all method +// Any add all method // usage: // Any("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) Any(pattern string, f FilterFunc) { +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { p.AddMethod("*", pattern, f) } -// add http method router +// AddMethod add http method router // usage: // AddMethod("get","/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) { +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { if _, ok := HTTPMETHOD[strings.ToUpper(method)]; method != "*" && !ok { panic("not support http method: " + method) } route := &controllerInfo{} route.pattern = pattern route.routerType = routerTypeRESTFul - route.runfunction = f + route.runFunction = f methods := make(map[string]string) if method == "*" { for _, val := range HTTPMETHOD { @@ -343,15 +348,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) { } } -// add user defined Handler -func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...interface{}) { +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { route := &controllerInfo{} route.pattern = pattern route.routerType = routerTypeHandler route.handler = h if len(options) > 0 { if _, ok := options[0].(bool); ok { - pattern = path.Join(pattern, "?:all") + pattern = path.Join(pattern, "?:all(.*)") } } for _, m := range HTTPMETHOD { @@ -359,21 +364,21 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ... } } -// Add auto router to ControllerRegistor. +// AddAuto router to ControllerRegister. // example beego.AddAuto(&MainContorlller{}), // MainController has method List and Page. // visit the url /main/list to execute List function // /main/page to execute Page function. -func (p *ControllerRegistor) AddAuto(c ControllerInterface) { +func (p *ControllerRegister) AddAuto(c ControllerInterface) { p.AddAutoPrefix("/", c) } -// Add auto router to ControllerRegistor with prefix. +// AddAutoPrefix Add auto router to ControllerRegister with prefix. // example beego.AddAutoPrefix("/admin",&MainContorlller{}), // MainController has method List and Page. // visit the url /admin/main/list to execute List function // /admin/main/page to execute Page function. -func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) { +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() ct := reflect.Indirect(reflectVal).Type() @@ -386,28 +391,28 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") - patternfix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) - patternfixInit := path.Join(prefix, controllerName, rt.Method(i).Name) + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) + patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) route.pattern = pattern for _, m := range HTTPMETHOD { p.addToRouter(m, pattern, route) p.addToRouter(m, patternInit, route) - p.addToRouter(m, patternfix, route) - p.addToRouter(m, patternfixInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) } } } } -// Add a FilterFunc with pattern rule and action constant. +// InsertFilter Add a FilterFunc with pattern rule and action constant. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { mr := new(FilterRouter) mr.tree = NewTree() mr.pattern = pattern mr.filterFunc = filter - if !RouterCaseSensitive { + if !BConfig.RouterCaseSensitive { pattern = strings.ToLower(pattern) } if len(params) == 0 { @@ -420,15 +425,15 @@ func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter Filter } // add Filter into -func (p *ControllerRegistor) insertFilterRouter(pos int, mr *FilterRouter) error { +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error { p.filters[pos] = append(p.filters[pos], mr) p.enableFilter = true return nil } -// UrlFor does another controller handler in this request function. +// URLFor does another controller handler in this request function. // it can access any controller method. -func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) string { +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { paths := strings.Split(endpoint, ".") if len(paths) <= 1 { Warn("urlfor endpoint must like path.controller.method") @@ -460,16 +465,16 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) stri return "" } -func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { - for k, subtree := range t.fixrouters { - u := path.Join(url, k) +func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { + for _, subtree := range t.fixrouters { + u := path.Join(url, subtree.prefix) ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod) if ok { return ok, u } } if t.wildcard != nil { - u := path.Join(url, url_placeholder) + u := path.Join(url, urlPlaceholder) ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod) if ok { return ok, u @@ -499,22 +504,21 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin if find { if l.regexps == nil { if len(l.wildcards) == 0 { - return true, strings.Replace(url, "/"+url_placeholder, "", 1) + tourl(params) + return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toUrl(params) } if len(l.wildcards) == 1 { if v, ok := params[l.wildcards[0]]; ok { delete(params, l.wildcards[0]) - return true, strings.Replace(url, url_placeholder, v, 1) + tourl(params) - } else { - return false, "" + return true, strings.Replace(url, urlPlaceholder, v, 1) + toUrl(params) } + return false, "" } if len(l.wildcards) == 3 && l.wildcards[0] == "." { if p, ok := params[":path"]; ok { if e, isok := params[":ext"]; isok { delete(params, ":path") delete(params, ":ext") - return true, strings.Replace(url, url_placeholder, p+"."+e, -1) + tourl(params) + return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toUrl(params) } } } @@ -526,45 +530,43 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin } if u, ok := params[v]; ok { delete(params, v) - url = strings.Replace(url, url_placeholder, u, 1) + url = strings.Replace(url, urlPlaceholder, u, 1) } else { if canskip { canskip = false continue - } else { - return false, "" } + return false, "" } } - return true, url + tourl(params) - } else { - var i int - var startreg bool - regurl := "" - for _, v := range strings.Trim(l.regexps.String(), "^$") { - if v == '(' { - startreg = true - continue - } else if v == ')' { - startreg = false - if v, ok := params[l.wildcards[i]]; ok { - delete(params, l.wildcards[i]) - regurl = regurl + v - i++ - } else { - break - } - } else if !startreg { - regurl = string(append([]rune(regurl), v)) + return true, url + toUrl(params) + } + var i int + var startreg bool + regurl := "" + for _, v := range strings.Trim(l.regexps.String(), "^$") { + if v == '(' { + startreg = true + continue + } else if v == ')' { + startreg = false + if v, ok := params[l.wildcards[i]]; ok { + delete(params, l.wildcards[i]) + regurl = regurl + v + i++ + } else { + break } + } else if !startreg { + regurl = string(append([]rune(regurl), v)) } - if l.regexps.MatchString(regurl) { - ps := strings.Split(regurl, "/") - for _, p := range ps { - url = strings.Replace(url, url_placeholder, p, 1) - } - return true, url + tourl(params) + } + if l.regexps.MatchString(regurl) { + ps := strings.Split(regurl, "/") + for _, p := range ps { + url = strings.Replace(url, urlPlaceholder, p, 1) } + return true, url + toUrl(params) } } } @@ -574,168 +576,137 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin return false, "" } +func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) { + if p.enableFilter { + if l, ok := p.filters[pos]; ok { + for _, filterR := range l { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + } + } + } + return false +} + // Implement http.Handler interface. -func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - starttime := time.Now() - var runrouter reflect.Type - var findrouter bool - var runMethod string - var routerInfo *controllerInfo - - w := &responseWriter{writer: rw} - - if RunMode == "dev" { - w.Header().Set("Server", BeegoServerName) - } - - // init context - context := &beecontext.Context{ - ResponseWriter: w, - Request: r, - Input: beecontext.NewInput(r), - Output: beecontext.NewOutput(), - } - context.Output.Context = context - context.Output.EnableGzip = EnableGzip - +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + startTime := time.Now() + var ( + runRouter reflect.Type + findRouter bool + runMethod string + routerInfo *controllerInfo + ) + context := p.pool.Get().(*beecontext.Context) + context.Reset(rw, r) + defer p.pool.Put(context) defer p.recoverPanic(context) + context.Output.EnableGzip = BConfig.EnableGzip + + if BConfig.RunMode == DEV { + context.Output.Header("Server", BConfig.ServerName) + } + var urlPath string - if !RouterCaseSensitive { + if !BConfig.RouterCaseSensitive { urlPath = strings.ToLower(r.URL.Path) } else { urlPath = r.URL.Path } - // defined filter function - do_filter := func(pos int) (started bool) { - if p.enableFilter { - if l, ok := p.filters[pos]; ok { - for _, filterR := range l { - if filterR.returnOnOutput && w.started { - return true - } - if ok, params := filterR.ValidRouter(urlPath); ok { - for k, v := range params { - if context.Input.Params == nil { - context.Input.Params = make(map[string]string) - } - context.Input.Params[k] = v - } - filterR.filterFunc(context) - } - if filterR.returnOnOutput && w.started { - return true - } - } - } - } - return false - } - // filter wrong httpmethod + // filter wrong http method if _, ok := HTTPMETHOD[r.Method]; !ok { - http.Error(w, "Method Not Allowed", 405) + http.Error(rw, "Method Not Allowed", 405) goto Admin } // filter for static file - if do_filter(BeforeStatic) { + if p.execFilter(context, BeforeStatic, urlPath) { goto Admin } serverStaticRouter(context) - if w.started { - findrouter = true + if context.ResponseWriter.Started { + findRouter = true goto Admin } // session init - if SessionOn { + if BConfig.WebConfig.Session.SessionOn { var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(w, r) + context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { Error(err) exception("503", context) return } defer func() { - context.Input.CruSession.SessionRelease(w) + context.Input.CruSession.SessionRelease(rw) }() } if r.Method != "GET" && r.Method != "HEAD" { - if CopyRequestBody && !context.Input.IsUpload() { - context.Input.CopyBody() + if BConfig.CopyRequestBody && !context.Input.IsUpload() { + context.Input.CopyBody(BConfig.MaxMemory) } - context.Input.ParseFormOrMulitForm(MaxMemory) + context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) } - if do_filter(BeforeRouter) { + if p.execFilter(context, BeforeRouter, urlPath) { goto Admin } - if context.Input.RunController != nil && context.Input.RunMethod != "" { - findrouter = true - runMethod = context.Input.RunMethod - runrouter = context.Input.RunController - } - - if !findrouter { - http_method := r.Method - - if http_method == "POST" && context.Input.Query("_method") == "PUT" { - http_method = "PUT" - } - - if http_method == "POST" && context.Input.Query("_method") == "DELETE" { - http_method = "DELETE" - } - - if t, ok := p.routers[http_method]; ok { - runObject, p := t.Match(urlPath) + if !findRouter { + httpMethod := r.Method + if t, ok := p.routers[httpMethod]; ok { + runObject := t.Match(urlPath, context) if r, ok := runObject.(*controllerInfo); ok { routerInfo = r - findrouter = true - if splat, ok := p[":splat"]; ok { - splatlist := strings.Split(splat, "/") - for k, v := range splatlist { - p[strconv.Itoa(k)] = v + findRouter = true + if splat := context.Input.Param(":splat"); splat != "" { + for k, v := range strings.Split(splat, "/") { + context.Input.SetParam(strconv.Itoa(k), v) } } - if p != nil { - context.Input.Params = p - } } } } //if no matches to url, throw a not found exception - if !findrouter { + if !findRouter { exception("404", context) goto Admin } - if findrouter { + if findRouter { //execute middleware filters - if do_filter(BeforeExec) { + if p.execFilter(context, BeforeExec, urlPath) { goto Admin } - isRunable := false + isRunnable := false if routerInfo != nil { if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { - isRunable = true - routerInfo.runfunction(context) + isRunnable = true + routerInfo.runFunction(context) } else { exception("405", context) goto Admin } } else if routerInfo.routerType == routerTypeHandler { - isRunable = true + isRunnable = true routerInfo.handler.ServeHTTP(rw, r) } else { - runrouter = routerInfo.controllerType + runRouter = routerInfo.controllerType method := r.Method if r.Method == "POST" && context.Input.Query("_method") == "PUT" { method = "PUT" @@ -753,33 +724,33 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } } - // also defined runrouter & runMethod from filter - if !isRunable { + // also defined runRouter & runMethod from filter + if !isRunnable { //Invoke the request handler - vc := reflect.New(runrouter) + vc := reflect.New(runRouter) execController, ok := vc.Interface().(ControllerInterface) if !ok { panic("controller is not ControllerInterface") } //call the controller init function - execController.Init(context, runrouter.Name(), runMethod, vc.Interface()) + execController.Init(context, runRouter.Name(), runMethod, vc.Interface()) //call prepare function execController.Prepare() //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if EnableXSRF { - execController.XsrfToken() + if BConfig.WebConfig.EnableXSRF { + execController.XSRFToken() if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" || (r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) { - execController.CheckXsrfCookie() + execController.CheckXSRFCookie() } } execController.URLMapping() - if !w.started { + if !context.ResponseWriter.Started { //exec main logic switch runMethod { case "GET": @@ -798,15 +769,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) execController.Options() default: if !execController.HandlerFunc(runMethod) { - in := make([]reflect.Value, 0) + var in []reflect.Value method := vc.MethodByName(runMethod) method.Call(in) } } //render template - if !w.started && context.Output.Status == 0 { - if AutoRender { + if !context.ResponseWriter.Started && context.Output.Status == 0 { + if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { panic(err) } @@ -814,69 +785,69 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } } - // finish all runrouter. release resource + // finish all runRouter. release resource execController.Finish() } //execute middleware filters - if do_filter(AfterExec) { + if p.execFilter(context, AfterExec, urlPath) { goto Admin } } - do_filter(FinishRouter) + p.execFilter(context, FinishRouter, urlPath) Admin: - timeend := time.Since(starttime) + timeDur := time.Since(startTime) //admin module record QPS - if EnableAdmin { - if FilterMonitorFunc(r.Method, r.URL.Path, timeend) { - if runrouter != nil { - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runrouter.Name(), timeend) + if BConfig.Listen.EnableAdmin { + if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { + if runRouter != nil { + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) } else { - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeend) + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur) } } } - if RunMode == "dev" || AccessLogs { - var devinfo string - if findrouter { + if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { + var devInfo string + if findRouter { if routerInfo != nil { - devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeend.String(), "match", routerInfo.pattern) + devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeDur.String(), "match", routerInfo.pattern) } else { - devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "match") + devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "match") } } else { - devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "notmatch") + devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch") } - if DefaultLogFilter == nil || !DefaultLogFilter.Filter(context) { - Debug(devinfo) + if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) { + Debug(devInfo) } } // Call WriteHeader if status code has been set changed if context.Output.Status != 0 { - w.writer.WriteHeader(context.Output.Status) + context.ResponseWriter.WriteHeader(context.Output.Status) } } -func (p *ControllerRegistor) recoverPanic(context *beecontext.Context) { +func (p *ControllerRegister) recoverPanic(context *beecontext.Context) { if err := recover(); err != nil { - if err == USERSTOPRUN { + if err == ErrAbort { return } - if !RecoverPanic { + if !BConfig.RecoverPanic { panic(err) } else { - if ErrorsShow { + if BConfig.EnableErrorsShow { if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { exception(fmt.Sprint(err), context) return } } var stack string - Critical("the request url is ", context.Input.Url()) + Critical("the request url is ", context.Input.URL()) Critical("Handler crashed with error", err) for i := 1; ; i++ { _, file, line, ok := runtime.Caller(i) @@ -886,52 +857,14 @@ func (p *ControllerRegistor) recoverPanic(context *beecontext.Context) { Critical(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) } - if RunMode == "dev" { + if BConfig.RunMode == DEV { showErr(err, context, stack) } } } } -//responseWriter is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler -type responseWriter struct { - writer http.ResponseWriter - started bool - status int -} - -// Header returns the header map that will be sent by WriteHeader. -func (w *responseWriter) Header() http.Header { - return w.writer.Header() -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (w *responseWriter) Write(p []byte) (int, error) { - w.started = true - return w.writer.Write(p) -} - -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. -func (w *responseWriter) WriteHeader(code int) { - w.status = code - w.started = true - w.writer.WriteHeader(code) -} - -// hijacker for http -func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := w.writer.(http.Hijacker) - if !ok { - return nil, nil, errors.New("webserver doesn't support hijacking") - } - return hj.Hijack() -} - -func tourl(params map[string]string) string { +func toUrl(params map[string]string) string { if len(params) == 0 { return "" } diff --git a/router_test.go b/router_test.go index 005f32d6..b0ae7a18 100644 --- a/router_test.go +++ b/router_test.go @@ -45,24 +45,24 @@ func (tc *TestController) List() { } func (tc *TestController) Params() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Params["0"] + tc.Ctx.Input.Params["1"] + tc.Ctx.Input.Params["2"])) + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) } func (tc *TestController) Myext() { tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) } -func (tc *TestController) GetUrl() { - tc.Ctx.Output.Body([]byte(tc.UrlFor(".Myext"))) +func (tc *TestController) GetURL() { + tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) } -func (t *TestController) GetParams() { - t.Ctx.WriteString(t.Ctx.Input.Query(":last") + "+" + - t.Ctx.Input.Query(":first") + "+" + t.Ctx.Input.Query("learn")) +func (tc *TestController) GetParams() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + + tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) } -func (t *TestController) GetManyRouter() { - t.Ctx.WriteString(t.Ctx.Input.Query(":id") + t.Ctx.Input.Query(":page")) +func (tc *TestController) GetManyRouter() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) } type ResStatus struct { @@ -70,29 +70,29 @@ type ResStatus struct { Msg string } -type JsonController struct { +type JSONController struct { Controller } -func (this *JsonController) Prepare() { - this.Data["json"] = "prepare" - this.ServeJson(true) +func (jc *JSONController) Prepare() { + jc.Data["json"] = "prepare" + jc.ServeJSON(true) } -func (this *JsonController) Get() { - this.Data["Username"] = "astaxie" - this.Ctx.Output.Body([]byte("ok")) +func (jc *JSONController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) } func TestUrlFor(t *testing.T) { handler := NewControllerRegister() handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/person/:last/:first", &TestController{}, "*:Param") - if a := handler.UrlFor("TestController.List"); a != "/api/list" { + if a := handler.URLFor("TestController.List"); a != "/api/list" { Info(a) t.Errorf("TestController.List must equal to /api/list") } - if a := handler.UrlFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { + if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) } } @@ -100,39 +100,39 @@ func TestUrlFor(t *testing.T) { func TestUrlFor3(t *testing.T) { handler := NewControllerRegister() handler.AddAuto(&TestController{}) - if a := handler.UrlFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) } - if a := handler.UrlFor("TestController.GetUrl"); a != "/test/geturl" && a != "/Test/GetUrl" { - t.Errorf("TestController.GetUrl must equal to /test/geturl, but get " + a) + if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) } } func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetUrl") + handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) - if handler.UrlFor("TestController.GetUrl", ":username", "astaxie") != "/v1/astaxie/edit" { - Info(handler.UrlFor("TestController.GetUrl")) + if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { + Info(handler.URLFor("TestController.GetURL")) t.Errorf("TestController.List must equal to /v1/astaxie/edit") } - if handler.UrlFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != + if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != "/v1/za/cms_12_123.html" { - Info(handler.UrlFor("TestController.List")) + Info(handler.URLFor("TestController.List")) t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") } - if handler.UrlFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != + if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != "/v1/za_cms/ttt_12_123.html" { - Info(handler.UrlFor("TestController.Param")) + Info(handler.URLFor("TestController.Param")) t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") } - if handler.UrlFor("TestController.Get", ":year", "1111", ":month", "11", + if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", ":title", "aaaa", ":entid", "aaaa") != "/1111/11/aaaa/aaaa" { - Info(handler.UrlFor("TestController.Get")) + Info(handler.URLFor("TestController.Get")) t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") } } @@ -270,7 +270,7 @@ func TestPrepare(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/json/list", &JsonController{}) + handler.Add("/json/list", &JSONController{}) handler.ServeHTTP(w, r) if w.Body.String() != `"prepare"` { t.Errorf(w.Body.String() + "user define func can't run") @@ -333,6 +333,18 @@ func TestRouterHandler(t *testing.T) { } } +func TestRouterHandlerAll(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + // // Benchmarks NewApp: // diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go index 827d55d9..d5be11d0 100644 --- a/session/couchbase/sess_couchbase.go +++ b/session/couchbase/sess_couchbase.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package couchbase for session provider +// Package couchbase for session provider // // depend on github.com/couchbaselabs/go-couchbasee // @@ -30,21 +30,22 @@ // } // // more docs: http://beego.me/docs/module/session.md -package session +package couchbase import ( "net/http" "strings" "sync" - "github.com/couchbaselabs/go-couchbase" + couchbase "github.com/couchbase/go-couchbase" "github.com/astaxie/beego/session" ) -var couchbpder = &CouchbaseProvider{} +var couchbpder = &Provider{} -type CouchbaseSessionStore struct { +// SessionStore store each session +type SessionStore struct { b *couchbase.Bucket sid string lock sync.RWMutex @@ -52,7 +53,8 @@ type CouchbaseSessionStore struct { maxlifetime int64 } -type CouchbaseProvider struct { +// Provider couchabse provided +type Provider struct { maxlifetime int64 savePath string pool string @@ -60,42 +62,47 @@ type CouchbaseProvider struct { b *couchbase.Bucket } -func (cs *CouchbaseSessionStore) Set(key, value interface{}) error { +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values[key] = value return nil } -func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} { +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { cs.lock.RLock() defer cs.lock.RUnlock() if v, ok := cs.values[key]; ok { return v - } else { - return nil } + return nil } -func (cs *CouchbaseSessionStore) Delete(key interface{}) error { +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() delete(cs.values, key) return nil } -func (cs *CouchbaseSessionStore) Flush() error { +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { cs.lock.Lock() defer cs.lock.Unlock() cs.values = make(map[interface{}]interface{}) return nil } -func (cs *CouchbaseSessionStore) SessionID() string { +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { return cs.sid } -func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) { +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { defer cs.b.Close() bo, err := session.EncodeGob(cs.values) @@ -106,7 +113,7 @@ func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) { cs.b.Set(cs.sid, int(cs.maxlifetime), bo) } -func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket { +func (cp *Provider) getBucket() *couchbase.Bucket { c, err := couchbase.Connect(cp.savePath) if err != nil { return nil @@ -125,10 +132,10 @@ func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket { return bucket } -// init couchbase session +// SessionInit init couchbase session // savepath like couchbase server REST/JSON URL // e.g. http://host:port/, Pool, Bucket -func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error { +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { cp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -144,8 +151,8 @@ func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) err return nil } -// read couchbase session by sid -func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -161,11 +168,13 @@ func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, erro } } - cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} return cs, nil } -func (cp *CouchbaseProvider) SessionExist(sid string) bool { +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { cp.b = cp.getBucket() defer cp.b.Close() @@ -173,12 +182,12 @@ func (cp *CouchbaseProvider) SessionExist(sid string) bool { if err := cp.b.Get(sid, &doc); err != nil || doc == nil { return false - } else { - return true } + return true } -func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -206,11 +215,12 @@ func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.Sess } } - cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} return cs, nil } -func (cp *CouchbaseProvider) SessionDestroy(sid string) error { +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { cp.b = cp.getBucket() defer cp.b.Close() @@ -218,11 +228,13 @@ func (cp *CouchbaseProvider) SessionDestroy(sid string) error { return nil } -func (cp *CouchbaseProvider) SessionGC() { +// SessionGC Recycle +func (cp *Provider) SessionGC() { return } -func (cp *CouchbaseProvider) SessionAll() int { +// SessionAll return all active session +func (cp *Provider) SessionAll() int { return 0 } diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go index 643b8817..68f37b08 100644 --- a/session/ledis/ledis_session.go +++ b/session/ledis/ledis_session.go @@ -1,4 +1,5 @@ -package session +// Package ledis provide session Provider +package ledis import ( "net/http" @@ -11,59 +12,58 @@ import ( "github.com/siddontang/ledisdb/ledis" ) -var ledispder = &LedisProvider{} +var ledispder = &Provider{} var c *ledis.DB -// ledis session store -type LedisSessionStore struct { +// SessionStore ledis session store +type SessionStore struct { sid string lock sync.RWMutex values map[interface{}]interface{} maxlifetime int64 } -// set value in ledis session -func (ls *LedisSessionStore) Set(key, value interface{}) error { +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values[key] = value return nil } -// get value in ledis session -func (ls *LedisSessionStore) Get(key interface{}) interface{} { +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { ls.lock.RLock() defer ls.lock.RUnlock() if v, ok := ls.values[key]; ok { return v - } else { - return nil } + return nil } -// delete value in ledis session -func (ls *LedisSessionStore) Delete(key interface{}) error { +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() delete(ls.values, key) return nil } -// clear all values in ledis session -func (ls *LedisSessionStore) Flush() error { +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { ls.lock.Lock() defer ls.lock.Unlock() ls.values = make(map[interface{}]interface{}) return nil } -// get ledis session id -func (ls *LedisSessionStore) SessionID() string { +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { return ls.sid } -// save session values to ledis -func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) { +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { b, err := session.EncodeGob(ls.values) if err != nil { return @@ -72,17 +72,17 @@ func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) { c.Expire([]byte(ls.sid), ls.maxlifetime) } -// ledis session provider -type LedisProvider struct { +// Provider ledis session provider +type Provider struct { maxlifetime int64 savePath string db int } -// init ledis session +// SessionInit init ledis session // savepath like ledis server saveDataPath,pool size // e.g. 127.0.0.1:6379,100,astaxie -func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error { +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { var err error lp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") @@ -106,8 +106,8 @@ func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error { return nil } -// read ledis session by sid -func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { kvs, err := c.Get([]byte(sid)) var kv map[interface{}]interface{} if len(kvs) == 0 { @@ -118,22 +118,21 @@ func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) { return nil, err } } - ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} return ls, nil } -// check ledis session exist by sid -func (lp *LedisProvider) SessionExist(sid string) bool { +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { count, _ := c.Exists([]byte(sid)) if count == 0 { return false - } else { - return true } + return true } -// generate new sid for ledis session -func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { count, _ := c.Exists([]byte(sid)) if count == 0 { // oldsid doesn't exists, set the new sid directly @@ -156,23 +155,23 @@ func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS return nil, err } } - ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} return ls, nil } -// delete ledis session by id -func (lp *LedisProvider) SessionDestroy(sid string) error { +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { c.Del([]byte(sid)) return nil } -// Impelment method, no used. -func (lp *LedisProvider) SessionGC() { +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { return } -// @todo -func (lp *LedisProvider) SessionAll() int { +// SessionAll return all active session +func (lp *Provider) SessionAll() int { return 0 } func init() { diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go index bb33075a..f1069bc9 100644 --- a/session/memcache/sess_memcache.go +++ b/session/memcache/sess_memcache.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package memcache for session provider +// Package memcache for session provider // // depend on github.com/bradfitz/gomemcache/memcache // @@ -30,7 +30,7 @@ // } // // more docs: http://beego.me/docs/module/session.md -package session +package memcache import ( "net/http" @@ -45,56 +45,55 @@ import ( var mempder = &MemProvider{} var client *memcache.Client -// memcache session store -type MemcacheSessionStore struct { +// SessionStore memcache session store +type SessionStore struct { sid string lock sync.RWMutex values map[interface{}]interface{} maxlifetime int64 } -// set value in memcache session -func (rs *MemcacheSessionStore) Set(key, value interface{}) error { +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value return nil } -// get value in memcache session -func (rs *MemcacheSessionStore) Get(key interface{}) interface{} { +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { return v - } else { - return nil } + return nil } -// delete value in memcache session -func (rs *MemcacheSessionStore) Delete(key interface{}) error { +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) return nil } -// clear all values in memcache session -func (rs *MemcacheSessionStore) Flush() error { +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) return nil } -// get memcache session id -func (rs *MemcacheSessionStore) SessionID() string { +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { return rs.sid } -// save session values to memcache -func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) { +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -103,7 +102,7 @@ func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) { client.Set(&item) } -// memcahe session provider +// MemProvider memcache session provider type MemProvider struct { maxlifetime int64 conninfo []string @@ -111,7 +110,7 @@ type MemProvider struct { password string } -// init memcache session +// SessionInit init memcache session // savepath like // e.g. 127.0.0.1:9090 func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { @@ -121,8 +120,8 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { return nil } -// read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -130,7 +129,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) { } item, err := client.Get(sid) if err != nil && err == memcache.ErrCacheMiss { - rs := &MemcacheSessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} + rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} return rs, nil } var kv map[interface{}]interface{} @@ -142,11 +141,11 @@ func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) { return nil, err } } - rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} return rs, nil } -// check memcache session exist by sid +// SessionExist check memcache session exist by sid func (rp *MemProvider) SessionExist(sid string) bool { if client == nil { if err := rp.connectInit(); err != nil { @@ -155,13 +154,12 @@ func (rp *MemProvider) SessionExist(sid string) bool { } if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { return false - } else { - return true } + return true } -// generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -195,11 +193,11 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionSto } } - rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} return rs, nil } -// delete memcache session by id +// SessionDestroy delete memcache session by id func (rp *MemProvider) SessionDestroy(sid string) error { if client == nil { if err := rp.connectInit(); err != nil { @@ -219,12 +217,12 @@ func (rp *MemProvider) connectInit() error { return nil } -// Impelment method, no used. +// SessionGC Impelment method, no used. func (rp *MemProvider) SessionGC() { return } -// @todo +// SessionAll return all activeSession func (rp *MemProvider) SessionAll() int { return 0 } diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 76a13932..969d26c9 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package mysql for session provider +// Package mysql for session provider // // depends on github.com/go-sql-driver/mysql: // @@ -38,7 +38,7 @@ // } // // more docs: http://beego.me/docs/module/session.md -package session +package mysql import ( "database/sql" @@ -47,82 +47,85 @@ import ( "time" "github.com/astaxie/beego/session" - + // import mysql driver _ "github.com/go-sql-driver/mysql" ) -var mysqlpder = &MysqlProvider{} +var ( + // TableName store the session in MySQL + TableName = "session" + mysqlpder = &Provider{} +) -// mysql session store -type MysqlSessionStore struct { +// SessionStore mysql session store +type SessionStore struct { c *sql.DB sid string lock sync.RWMutex values map[interface{}]interface{} } -// set value in mysql session. +// Set value in mysql session. // it is temp value in map. -func (st *MysqlSessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value return nil } -// get value from mysql session -func (st *MysqlSessionStore) Get(key interface{}) interface{} { +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { return v - } else { - return nil } + return nil } -// delete value in mysql session -func (st *MysqlSessionStore) Delete(key interface{}) error { +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) return nil } -// clear all values in mysql session -func (st *MysqlSessionStore) Flush() error { +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) return nil } -// get session id of this mysql session store -func (st *MysqlSessionStore) SessionID() string { +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { return st.sid } -// save mysql session values to database. +// SessionRelease save mysql session values to database. // must call this method to save values to database. -func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { return } - st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?", + st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", b, time.Now().Unix(), st.sid) } -// mysql session provider -type MysqlProvider struct { +// Provider mysql session provider +type Provider struct { maxlifetime int64 savePath string } // connect to mysql -func (mp *MysqlProvider) connectInit() *sql.DB { +func (mp *Provider) connectInit() *sql.DB { db, e := sql.Open("mysql", mp.savePath) if e != nil { return nil @@ -130,22 +133,22 @@ func (mp *MysqlProvider) connectInit() *sql.DB { return db } -// init mysql session. +// SessionInit init mysql session. // savepath is the connection string of mysql. -func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } -// get mysql session by sid -func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=?", sid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { - c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) } var kv map[interface{}]interface{} @@ -157,34 +160,33 @@ func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) { return nil, err } } - rs := &MysqlSessionStore{c: c, sid: sid, values: kv} + rs := &SessionStore{c: c, sid: sid, values: kv} return rs, nil } -// check mysql session exist -func (mp *MysqlProvider) SessionExist(sid string) bool { +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { c := mp.connectInit() defer c.Close() - row := c.QueryRow("select session_data from session where session_key=?", sid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { return false - } else { - return true } + return true } -// generate new sid for mysql session -func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=?", oldsid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { - c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) } - c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid) + c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) var kv map[interface{}]interface{} if len(sessiondata) == 0 { kv = make(map[interface{}]interface{}) @@ -194,32 +196,32 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionS return nil, err } } - rs := &MysqlSessionStore{c: c, sid: sid, values: kv} + rs := &SessionStore{c: c, sid: sid, values: kv} return rs, nil } -// delete mysql session by sid -func (mp *MysqlProvider) SessionDestroy(sid string) error { +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { c := mp.connectInit() - c.Exec("DELETE FROM session where session_key=?", sid) + c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() return nil } -// delete expired values in mysql session -func (mp *MysqlProvider) SessionGC() { +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { c := mp.connectInit() - c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() return } -// count values in mysql session -func (mp *MysqlProvider) SessionAll() int { +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { c := mp.connectInit() defer c.Close() var total int - err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) + err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) if err != nil { return 0 } diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go index ac9a1612..73f9c13a 100644 --- a/session/postgres/sess_postgresql.go +++ b/session/postgres/sess_postgresql.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package postgresql for session provider +// Package postgres for session provider // // depends on github.com/lib/pq: // @@ -27,9 +27,9 @@ // session_expiry timestamp NOT NULL, // CONSTRAINT session_key PRIMARY KEY(session_key) // ); - +// // will be activated with these settings in app.conf: - +// // SessionOn = true // SessionProvider = postgresql // SessionSavePath = "user=a password=b dbname=c sslmode=disable" @@ -48,7 +48,7 @@ // } // // more docs: http://beego.me/docs/module/session.md -package session +package postgres import ( "database/sql" @@ -57,64 +57,63 @@ import ( "time" "github.com/astaxie/beego/session" - + // import postgresql Driver _ "github.com/lib/pq" ) -var postgresqlpder = &PostgresqlProvider{} +var postgresqlpder = &Provider{} -// postgresql session store -type PostgresqlSessionStore struct { +// SessionStore postgresql session store +type SessionStore struct { c *sql.DB sid string lock sync.RWMutex values map[interface{}]interface{} } -// set value in postgresql session. +// Set value in postgresql session. // it is temp value in map. -func (st *PostgresqlSessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value return nil } -// get value from postgresql session -func (st *PostgresqlSessionStore) Get(key interface{}) interface{} { +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { return v - } else { - return nil } + return nil } -// delete value in postgresql session -func (st *PostgresqlSessionStore) Delete(key interface{}) error { +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) return nil } -// clear all values in postgresql session -func (st *PostgresqlSessionStore) Flush() error { +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) return nil } -// get session id of this postgresql session store -func (st *PostgresqlSessionStore) SessionID() string { +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { return st.sid } -// save postgresql session values to database. +// SessionRelease save postgresql session values to database. // must call this method to save values to database. -func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -125,14 +124,14 @@ func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) { } -// postgresql session provider -type PostgresqlProvider struct { +// Provider postgresql session provider +type Provider struct { maxlifetime int64 savePath string } // connect to postgresql -func (mp *PostgresqlProvider) connectInit() *sql.DB { +func (mp *Provider) connectInit() *sql.DB { db, e := sql.Open("postgres", mp.savePath) if e != nil { return nil @@ -140,16 +139,16 @@ func (mp *PostgresqlProvider) connectInit() *sql.DB { return db } -// init postgresql session. +// SessionInit init postgresql session. // savepath is the connection string of postgresql. -func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } -// get postgresql session by sid -func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte @@ -174,12 +173,12 @@ func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, err return nil, err } } - rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv} + rs := &SessionStore{c: c, sid: sid, values: kv} return rs, nil } -// check postgresql session exist -func (mp *PostgresqlProvider) SessionExist(sid string) bool { +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) @@ -188,13 +187,12 @@ func (mp *PostgresqlProvider) SessionExist(sid string) bool { if err == sql.ErrNoRows { return false - } else { - return true } + return true } -// generate new sid for postgresql session -func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", oldsid) var sessiondata []byte @@ -213,28 +211,28 @@ func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.Ses return nil, err } } - rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv} + rs := &SessionStore{c: c, sid: sid, values: kv} return rs, nil } -// delete postgresql session by sid -func (mp *PostgresqlProvider) SessionDestroy(sid string) error { +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=$1", sid) c.Close() return nil } -// delete expired values in postgresql session -func (mp *PostgresqlProvider) SessionGC() { +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() return } -// count values in postgresql session -func (mp *PostgresqlProvider) SessionAll() int { +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { c := mp.connectInit() defer c.Close() var total int diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go index d31feb2c..99a672d5 100644 --- a/session/redis/sess_redis.go +++ b/session/redis/sess_redis.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package redis for session provider +// Package redis for session provider // // depend on github.com/garyburd/redigo/redis // @@ -30,7 +30,7 @@ // } // // more docs: http://beego.me/docs/module/session.md -package session +package redis import ( "net/http" @@ -43,15 +43,13 @@ import ( "github.com/garyburd/redigo/redis" ) -var redispder = &RedisProvider{} +var redispder = &Provider{} // redis max pool size -var MAX_POOL_SIZE = 100 +var MaxPoolSize = 100 -var redisPool chan redis.Conn - -// redis session store -type RedisSessionStore struct { +// SessionStore redis session store +type SessionStore struct { p *redis.Pool sid string lock sync.RWMutex @@ -59,61 +57,63 @@ type RedisSessionStore struct { maxlifetime int64 } -// set value in redis session -func (rs *RedisSessionStore) Set(key, value interface{}) error { +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value return nil } -// get value in redis session -func (rs *RedisSessionStore) Get(key interface{}) interface{} { +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { return v - } else { - return nil } + return nil } -// delete value in redis session -func (rs *RedisSessionStore) Delete(key interface{}) error { +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) return nil } -// clear all values in redis session -func (rs *RedisSessionStore) Flush() error { +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) return nil } -// get redis session id -func (rs *RedisSessionStore) SessionID() string { +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { return rs.sid } -// save session values to redis -func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) { - c := rs.p.Get() - defer c.Close() +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return } - c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) + c := rs.p.Get() + defer c.Close() + // Update session value if exists or error. + if existed, err := redis.Bool(c.Do("EXISTS", rs.sid)); existed || err != nil { + c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) + } } -// redis session provider -type RedisProvider struct { +// Provider redis session provider +type Provider struct { maxlifetime int64 savePath string poolsize int @@ -122,10 +122,10 @@ type RedisProvider struct { poollist *redis.Pool } -// init redis session +// SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379,100,astaxie,0 -func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -134,12 +134,12 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize <= 0 { - rp.poolsize = MAX_POOL_SIZE + rp.poolsize = MaxPoolSize } else { rp.poolsize = poolsize } } else { - rp.poolsize = MAX_POOL_SIZE + rp.poolsize = MaxPoolSize } if len(configs) > 2 { rp.password = configs[2] @@ -176,8 +176,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { return rp.poollist.Get().Err() } -// read redis session by sid -func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) { +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { c := rp.poollist.Get() defer c.Close() @@ -192,24 +192,23 @@ func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) { } } - rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} return rs, nil } -// check redis session exist by sid -func (rp *RedisProvider) SessionExist(sid string) bool { +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { c := rp.poollist.Get() defer c.Close() if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { return false - } else { - return true } + return true } -// generate new sid for redis session -func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := rp.poollist.Get() defer c.Close() @@ -234,12 +233,12 @@ func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS } } - rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} return rs, nil } -// delete redis session by id -func (rp *RedisProvider) SessionDestroy(sid string) error { +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { c := rp.poollist.Get() defer c.Close() @@ -247,13 +246,13 @@ func (rp *RedisProvider) SessionDestroy(sid string) error { return nil } -// Impelment method, no used. -func (rp *RedisProvider) SessionGC() { +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { return } -// @todo -func (rp *RedisProvider) SessionAll() int { +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { return 0 } diff --git a/session/redis/sess_redis_test.go b/session/redis/sess_redis_test.go new file mode 100644 index 00000000..0c634428 --- /dev/null +++ b/session/redis/sess_redis_test.go @@ -0,0 +1,64 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "testing" +) + +func TestSessionRelease(t *testing.T) { + + provider := Provider{} + if err := provider.SessionInit(3, "127.0.0.1:6379"); err != nil { + t.Fatal("init session err,", err) + } + + sessionID := "beegosessionid_00001" + + session, err := provider.SessionRegenerate("", sessionID) + if err != nil { + t.Fatal("new session error,", err) + } + + // set item. + session.Set("k1", "v1") + // update. + session.SessionRelease(nil) + + session, err = provider.SessionRead(sessionID) + if err != nil { + t.Fatal("read session error,", err) + } + if v1 := session.Get("k1"); v1 == nil { + t.Fatal("want v1 got nil") + } else if v, _ := v1.(string); v != "v1" { + t.Fatalf("want v1 got %s", v) + } + + // delete + provider.SessionDestroy(sessionID) + session.Set("k2", "v2") + + session.SessionRelease(nil) + + session, err = provider.SessionRead(sessionID) + if err != nil { + t.Fatal("read session error,", err) + } + if session.Get("k1") != nil || session.Get("k2") != nil { + t.Fatalf("want emtpy session value,got %s,%s", session.Get("k1"), session.Get("k2")) + } + +} diff --git a/session/sess_cookie.go b/session/sess_cookie.go index 01dc505c..3fefa360 100644 --- a/session/sess_cookie.go +++ b/session/sess_cookie.go @@ -25,7 +25,7 @@ import ( var cookiepder = &CookieProvider{} -// Cookie SessionStore +// CookieSessionStore Cookie SessionStore type CookieSessionStore struct { sid string values map[interface{}]interface{} // session data @@ -47,9 +47,8 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} { defer st.lock.RUnlock() if v, ok := st.values[key]; ok { return v - } else { - return nil } + return nil } // Delete value in cookie session @@ -60,7 +59,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error { return nil } -// Clean all values in cookie session +// Flush Clean all values in cookie session func (st *CookieSessionStore) Flush() error { st.lock.Lock() defer st.lock.Unlock() @@ -68,12 +67,12 @@ func (st *CookieSessionStore) Flush() error { return nil } -// Return id of this cookie session +// SessionID Return id of this cookie session func (st *CookieSessionStore) SessionID() string { return st.sid } -// Write cookie session to http response cookie +// SessionRelease Write cookie session to http response cookie func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { str, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, @@ -101,14 +100,14 @@ type cookieConfig struct { Maxage int `json:"maxage"` } -// Cookie session provider +// CookieProvider Cookie session provider type CookieProvider struct { maxlifetime int64 config *cookieConfig block cipher.Block } -// Init cookie session provider with max lifetime and config json. +// SessionInit Init cookie session provider with max lifetime and config json. // maxlifetime is ignored. // json config: // securityKey - hash string @@ -136,9 +135,9 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error return nil } -// Get SessionStore in cooke. +// SessionRead Get SessionStore in cooke. // decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) { +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, @@ -150,32 +149,32 @@ func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) { return rs, nil } -// Cookie session is always existed +// SessionExist Cookie session is always existed func (pder *CookieProvider) SessionExist(sid string) bool { return true } -// Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { return nil, nil } -// Implement method, no used. +// SessionDestroy Implement method, no used. func (pder *CookieProvider) SessionDestroy(sid string) error { return nil } -// Implement method, no used. +// SessionGC Implement method, no used. func (pder *CookieProvider) SessionGC() { return } -// Implement method, return 0. +// SessionAll Implement method, return 0. func (pder *CookieProvider) SessionAll() int { return 0 } -// Implement method, no used. +// SessionUpdate Implement method, no used. func (pder *CookieProvider) SessionUpdate(sid string) error { return nil } diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go index fe3ac806..b5982260 100644 --- a/session/sess_cookie_test.go +++ b/session/sess_cookie_test.go @@ -53,3 +53,44 @@ func TestCookie(t *testing.T) { } } } + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + globalSessions, err := NewManager("cookie", config) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destory session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destory session and reqeust again ,get cookie session id is same.") + } +} diff --git a/session/sess_file.go b/session/sess_file.go index b1084acf..9265b030 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -32,7 +32,7 @@ var ( gcmaxlifetime int64 ) -// File session store +// FileSessionStore File session store type FileSessionStore struct { sid string lock sync.RWMutex @@ -53,9 +53,8 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} { defer fs.lock.RUnlock() if v, ok := fs.values[key]; ok { return v - } else { - return nil } + return nil } // Delete value in file session by given key @@ -66,7 +65,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error { return nil } -// Clean all values in file session +// Flush Clean all values in file session func (fs *FileSessionStore) Flush() error { fs.lock.Lock() defer fs.lock.Unlock() @@ -74,12 +73,12 @@ func (fs *FileSessionStore) Flush() error { return nil } -// Get file session store id +// SessionID Get file session store id func (fs *FileSessionStore) SessionID() string { return fs.sid } -// Write file session to local file with Gob string +// SessionRelease Write file session to local file with Gob string func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { b, err := EncodeGob(fs.values) if err != nil { @@ -100,14 +99,14 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { f.Close() } -// File session provider +// FileProvider File session provider type FileProvider struct { lock sync.RWMutex maxlifetime int64 savePath string } -// Init file session provider. +// SessionInit Init file session provider. // savePath sets the session files path. func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { fp.maxlifetime = maxlifetime @@ -115,10 +114,10 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { return nil } -// Read file session by sid. +// SessionRead Read file session by sid. // if file is not exist, create it. // the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) { +func (fp *FileProvider) SessionRead(sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -154,7 +153,7 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) { return ss, nil } -// Check file session exist. +// SessionExist Check file session exist. // it checkes the file named from sid exist or not. func (fp *FileProvider) SessionExist(sid string) bool { filepder.lock.Lock() @@ -163,12 +162,11 @@ func (fp *FileProvider) SessionExist(sid string) bool { _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) if err == nil { return true - } else { - return false } + return false } -// Remove all files in this save path +// SessionDestroy Remove all files in this save path func (fp *FileProvider) SessionDestroy(sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -176,7 +174,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error { return nil } -// Recycle files in save path +// SessionGC Recycle files in save path func (fp *FileProvider) SessionGC() { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -185,7 +183,7 @@ func (fp *FileProvider) SessionGC() { filepath.Walk(fp.savePath, gcpath) } -// Get active file session number. +// SessionAll Get active file session number. // it walks save path to count files. func (fp *FileProvider) SessionAll() int { a := &activeSession{} @@ -199,9 +197,9 @@ func (fp *FileProvider) SessionAll() int { return a.total } -// Generate new sid for file session. +// SessionRegenerate Generate new sid for file session. // it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -269,14 +267,14 @@ type activeSession struct { total int } -func (self *activeSession) visit(paths string, f os.FileInfo, err error) error { +func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { if err != nil { return err } if f.IsDir() { return nil } - self.total = self.total + 1 + as.total = as.total + 1 return nil } diff --git a/session/sess_mem.go b/session/sess_mem.go index dd066703..dd61ef57 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -23,7 +23,7 @@ import ( var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} -// memory session store. +// MemSessionStore memory session store. // it saved sessions in a map in memory. type MemSessionStore struct { sid string //session id @@ -32,7 +32,7 @@ type MemSessionStore struct { lock sync.RWMutex } -// set value to memory session +// Set value to memory session func (st *MemSessionStore) Set(key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() @@ -40,18 +40,17 @@ func (st *MemSessionStore) Set(key, value interface{}) error { return nil } -// get value from memory session by key +// Get value from memory session by key func (st *MemSessionStore) Get(key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.value[key]; ok { return v - } else { - return nil } + return nil } -// delete in memory session by key +// Delete in memory session by key func (st *MemSessionStore) Delete(key interface{}) error { st.lock.Lock() defer st.lock.Unlock() @@ -59,7 +58,7 @@ func (st *MemSessionStore) Delete(key interface{}) error { return nil } -// clear all values in memory session +// Flush clear all values in memory session func (st *MemSessionStore) Flush() error { st.lock.Lock() defer st.lock.Unlock() @@ -67,15 +66,16 @@ func (st *MemSessionStore) Flush() error { return nil } -// get this id of memory session store +// SessionID get this id of memory session store func (st *MemSessionStore) SessionID() string { return st.sid } -// Implement method, no used. +// SessionRelease Implement method, no used. func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { } +// MemProvider Implement the provider interface type MemProvider struct { lock sync.RWMutex // locker sessions map[string]*list.Element // map in memory @@ -84,44 +84,42 @@ type MemProvider struct { savePath string } -// init memory session +// SessionInit init memory session func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { pder.maxlifetime = maxlifetime pder.savePath = savePath return nil } -// get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { go pder.SessionUpdate(sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil - } else { - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushBack(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushBack(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil } -// check session store exist in memory session by sid +// SessionExist check session store exist in memory session by sid func (pder *MemProvider) SessionExist(sid string) bool { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { return true - } else { - return false } + return false } -// generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { go pder.SessionUpdate(oldsid) @@ -132,18 +130,17 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er delete(pder.sessions, oldsid) pder.lock.Unlock() return element.Value.(*MemSessionStore), nil - } else { - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushBack(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushBack(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil } -// delete session store in memory session by id +// SessionDestroy delete session store in memory session by id func (pder *MemProvider) SessionDestroy(sid string) error { pder.lock.Lock() defer pder.lock.Unlock() @@ -155,7 +152,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error { return nil } -// clean expired session stores in memory session +// SessionGC clean expired session stores in memory session func (pder *MemProvider) SessionGC() { pder.lock.RLock() for { @@ -177,12 +174,12 @@ func (pder *MemProvider) SessionGC() { pder.lock.RUnlock() } -// get count number of memory session +// SessionAll get count number of memory session func (pder *MemProvider) SessionAll() int { return pder.list.Len() } -// expand time of session store by id in memory session +// SessionUpdate expand time of session store by id in memory session func (pder *MemProvider) SessionUpdate(sid string) error { pder.lock.Lock() defer pder.lock.Unlock() diff --git a/session/sess_utils.go b/session/sess_utils.go index 9ae74528..d7db5ba8 100644 --- a/session/sess_utils.go +++ b/session/sess_utils.go @@ -43,6 +43,7 @@ func init() { gob.Register(map[int]int64{}) } +// EncodeGob encode the obj to gob func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { for _, v := range obj { gob.Register(v) @@ -56,6 +57,7 @@ func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { return buf.Bytes(), nil } +// DecodeGob decode data to map func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { buf := bytes.NewBuffer(encoded) dec := gob.NewDecoder(buf) @@ -178,11 +180,11 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime return nil, err } // 5. DecodeGob. - if dst, err := DecodeGob(b); err != nil { + dst, err := DecodeGob(b) + if err != nil { return nil, err - } else { - return dst, nil } + return dst, nil } // Encoding ------------------------------------------------------------------- diff --git a/session/session.go b/session/session.go index f0895de1..39d475fc 100644 --- a/session/session.go +++ b/session/session.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// package session provider +// Package session provider // // Usage: // import( @@ -20,7 +20,7 @@ // ) // // func init() { -// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "sessionIDHashFunc": "sha1", "sessionIDHashKey": "", "cookieLifeTime": 3600, "providerConfig": ""}`) +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) // go globalSessions.GC() // } // @@ -37,8 +37,8 @@ import ( "time" ) -// SessionStore contains all data for one session process with specific id. -type SessionStore interface { +// Store contains all data for one session process with specific id. +type Store interface { Set(key, value interface{}) error //set session value Get(key interface{}) interface{} //get session value Delete(key interface{}) error //delete session value @@ -51,9 +51,9 @@ type SessionStore interface { // it can operate a SessionStore by its id. type Provider interface { SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (SessionStore, error) + SessionRead(sid string) (Store, error) SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (SessionStore, error) + SessionRegenerate(oldsid, sid string) (Store, error) SessionDestroy(sid string) error SessionAll() int //get all active session SessionGC() @@ -83,7 +83,7 @@ type managerConfig struct { CookieLifeTime int `json:"cookieLifeTime"` ProviderConfig string `json:"providerConfig"` Domain string `json:"domain"` - SessionIdLength int64 `json:"sessionIdLength"` + SessionIDLength int64 `json:"sessionIDLength"` } // Manager contains Provider and its configuration. @@ -92,7 +92,7 @@ type Manager struct { config *managerConfig } -// Create new Manager with provider name and json config string. +// NewManager Create new Manager with provider name and json config string. // provider name: // 1. cookie // 2. file @@ -123,8 +123,8 @@ func NewManager(provideName, config string) (*Manager, error) { return nil, err } - if cf.SessionIdLength == 0 { - cf.SessionIdLength = 16 + if cf.SessionIDLength == 0 { + cf.SessionIDLength = 16 } return &Manager{ @@ -133,104 +133,108 @@ func NewManager(provideName, config string) (*Manager, error) { }, nil } -// Start session. generate or read the session id from http request. -// if session id exists, return SessionStore with this id. -func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore, err error) { +// getSid retrieves session identifier from HTTP Request. +// First try to retrieve id by reading from cookie, session cookie name is configurable, +// if not exist, then retrieve id from querying parameters. +// +// error is not nil when there is anything wrong. +// sid is empty when need to generate a new session id +// otherwise return an valid session id. +func (manager *Manager) getSid(r *http.Request) (string, error) { cookie, errs := r.Cookie(manager.config.CookieName) - if errs != nil || cookie.Value == "" { - sid, errs := manager.sessionId(r) + if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 { + errs := r.ParseForm() if errs != nil { - return nil, errs - } - session, err = manager.provider.SessionRead(sid) - cookie = &http.Cookie{ - Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: true, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) - } else { - sid, errs := url.QueryUnescape(cookie.Value) - if errs != nil { - return nil, errs - } - if manager.provider.SessionExist(sid) { - session, err = manager.provider.SessionRead(sid) - } else { - sid, err = manager.sessionId(r) - if err != nil { - return nil, err - } - session, err = manager.provider.SessionRead(sid) - cookie = &http.Cookie{ - Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: true, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) + return "", errs } + + sid := r.FormValue(manager.config.CookieName) + return sid, nil } + + // HTTP Request contains cookie for sessionid info. + return url.QueryUnescape(cookie.Value) +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { + sid, errs := manager.getSid(r) + if errs != nil { + return nil, errs + } + + if sid != "" && manager.provider.SessionExist(sid) { + return manager.provider.SessionRead(sid) + } + + // Generate a new session + sid, errs = manager.sessionID() + if errs != nil { + return nil, errs + } + + session, err = manager.provider.SessionRead(sid) + cookie := &http.Cookie{ + Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: true, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + return } -// Destroy session by its id in http request cookie. +// SessionDestroy Destroy session by its id in http request cookie. func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { return - } else { - manager.provider.SessionDestroy(cookie.Value) + } + manager.provider.SessionDestroy(cookie.Value) + if manager.config.EnableSetCookie { expiration := time.Now() - cookie := http.Cookie{Name: manager.config.CookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} - http.SetCookie(w, &cookie) + + http.SetCookie(w, cookie) } } -// Get SessionStore by its id. -func (manager *Manager) GetSessionStore(sid string) (sessions SessionStore, err error) { +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { sessions, err = manager.provider.SessionRead(sid) return } -// Start session gc process. +// GC Start session gc process. // it can do gc in times after gc lifetime. func (manager *Manager) GC() { manager.provider.SessionGC() time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } -// Regenerate a session id for this SessionStore who's id is saving in http request. -func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { - sid, err := manager.sessionId(r) +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { + sid, err := manager.sessionID() if err != nil { return } cookie, err := r.Cookie(manager.config.CookieName) - if err != nil && cookie.Value == "" { + if err != nil || cookie.Value == "" { //delete old cookie session, _ = manager.provider.SessionRead(sid) cookie = &http.Cookie{Name: manager.config.CookieName, @@ -251,23 +255,25 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque cookie.MaxAge = manager.config.CookieLifeTime cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) } - http.SetCookie(w, cookie) + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } r.AddCookie(cookie) return } -// Get all active sessions count number. +// GetActiveSession Get all active sessions count number. func (manager *Manager) GetActiveSession() int { return manager.provider.SessionAll() } -// Set cookie with https. +// SetSecure Set cookie with https. func (manager *Manager) SetSecure(secure bool) { manager.config.Secure = secure } -func (manager *Manager) sessionId(r *http.Request) (string, error) { - b := make([]byte, manager.config.SessionIdLength) +func (manager *Manager) sessionID() (string, error) { + b := make([]byte, manager.config.SessionIDLength) n, err := rand.Read(b) if n != len(b) || err != nil { return "", fmt.Errorf("Could not successfully read from the system CSPRNG.") diff --git a/staticfile.go b/staticfile.go index 7c1ed98c..f9f3dc3e 100644 --- a/staticfile.go +++ b/staticfile.go @@ -15,107 +15,183 @@ package beego import ( + "bytes" "net/http" "os" "path" + "path/filepath" "strconv" "strings" + "sync" + + "errors" + + "time" "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" ) +var notStaticRequestErr = errors.New("request not a static file request") + func serverStaticRouter(ctx *context.Context) { if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { return } - requestPath := path.Clean(ctx.Input.Request.URL.Path) - i := 0 - for prefix, staticDir := range StaticDir { + + forbidden, filePath, fileInfo, err := lookupFile(ctx) + if err == notStaticRequestErr { + return + } + + if forbidden { + exception("403", ctx) + return + } + + if filePath == "" || fileInfo == nil { + if BConfig.RunMode == DEV { + Warn("Can't find/open the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + if fileInfo.IsDir() { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return + } + + var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + var acceptEncoding string + if enableCompress { + acceptEncoding = context.ParseEncoding(ctx.Request) + } + b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) + if err != nil { + if BConfig.RunMode == DEV { + Warn("Can't compress the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + + if b { + ctx.Output.Header("Content-Encoding", n) + } else { + ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) + } + + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch) + return + +} + +type serveContentHolder struct { + *bytes.Reader + modTime time.Time + size int64 + encoding string +} + +var ( + staticFileMap = make(map[string]*serveContentHolder) + mapLock sync.Mutex +) + +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { + mapKey := acceptEncoding + ":" + filePath + mapFile, _ := staticFileMap[mapKey] + if isOk(mapFile, fi) { + return mapFile.encoding != "", mapFile.encoding, mapFile, nil + } + mapLock.Lock() + defer mapLock.Unlock() + if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) { + file, err := os.Open(filePath) + if err != nil { + return false, "", nil, err + } + defer file.Close() + var bufferWriter bytes.Buffer + _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) + if err != nil { + return false, "", nil, err + } + mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} + staticFileMap[mapKey] = mapFile + } + + return mapFile.encoding != "", mapFile.encoding, mapFile, nil +} + +func isOk(s *serveContentHolder, fi os.FileInfo) bool { + if s == nil { + return false + } + return s.modTime == fi.ModTime() && s.size == fi.Size() +} + +// isStaticCompress detect static files +func isStaticCompress(filePath string) bool { + for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { + if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { + return true + } + } + return false +} + +// searchFile search the file by url path +// if none the static file prefix matches ,return notStaticRequestErr +func searchFile(ctx *context.Context) (string, os.FileInfo, error) { + requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) + // special processing : favicon.ico/robots.txt can be in any static dir + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + file := path.Join(".", requestPath) + if fi, _ := os.Stat(file); fi != nil { + return file, fi, nil + } + for _, staticDir := range BConfig.WebConfig.StaticDir { + filePath := path.Join(staticDir, requestPath) + if fi, _ := os.Stat(filePath); fi != nil { + return filePath, fi, nil + } + } + return "", nil, errors.New(requestPath + " file not find") + } + + for prefix, staticDir := range BConfig.WebConfig.StaticDir { if len(prefix) == 0 { continue } - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - file := path.Join(staticDir, requestPath) - if utils.FileExists(file) { - http.ServeFile(ctx.ResponseWriter, ctx.Request, file) - return - } else { - i++ - if i == len(StaticDir) { - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } else { - continue - } - } + if !strings.Contains(requestPath, prefix) { + continue } - if strings.HasPrefix(requestPath, prefix) { - if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { - continue - } - file := path.Join(staticDir, requestPath[len(prefix):]) - finfo, err := os.Stat(file) - if err != nil { - if RunMode == "dev" { - Warn("Can't find the file:", file, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - //if the request is dir and DirectoryIndex is false then - if finfo.IsDir() { - if !DirectoryIndex { - exception("403", ctx) - return - } else if ctx.Input.Request.URL.Path[len(ctx.Input.Request.URL.Path)-1] != '/' { - http.Redirect(ctx.ResponseWriter, ctx.Request, ctx.Input.Request.URL.Path+"/", 302) - return - } - } else if strings.HasSuffix(requestPath, "/index.html") { - file := path.Join(staticDir, requestPath) - if utils.FileExists(file) { - http.ServeFile(ctx.ResponseWriter, ctx.Request, file) - return - } - } - - //This block obtained from (https://github.com/smithfox/beego) - it should probably get merged into astaxie/beego after a pull request - isStaticFileToCompress := false - if StaticExtensionsToGzip != nil && len(StaticExtensionsToGzip) > 0 { - for _, statExtension := range StaticExtensionsToGzip { - if strings.HasSuffix(strings.ToLower(file), strings.ToLower(statExtension)) { - isStaticFileToCompress = true - break - } - } - } - - if isStaticFileToCompress { - var contentEncoding string - if EnableGzip { - contentEncoding = getAcceptEncodingZip(ctx.Request) - } - - memzipfile, err := openMemZipFile(file, contentEncoding) - if err != nil { - return - } - - if contentEncoding == "gzip" { - ctx.Output.Header("Content-Encoding", "gzip") - } else if contentEncoding == "deflate" { - ctx.Output.Header("Content-Encoding", "deflate") - } else { - ctx.Output.Header("Content-Length", strconv.FormatInt(finfo.Size(), 10)) - } - - http.ServeContent(ctx.ResponseWriter, ctx.Request, file, finfo.ModTime(), memzipfile) - - } else { - http.ServeFile(ctx.ResponseWriter, ctx.Request, file) - } - return + if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + continue + } + filePath := path.Join(staticDir, requestPath[len(prefix):]) + if fi, err := os.Stat(filePath); fi != nil { + return filePath, fi, err } } + return "", nil, notStaticRequestErr +} + +// lookupFile find the file to serve +// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) +// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex +func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { + fp, fi, err := searchFile(ctx) + if fp == "" || fi == nil { + return false, "", nil, err + } + if !fi.IsDir() { + return false, fp, fi, err + } + ifp := filepath.Join(fp, "index.html") + if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { + return false, ifp, ifi, err + } + return !BConfig.WebConfig.DirectoryIndex, fp, fi, err } diff --git a/staticfile_test.go b/staticfile_test.go new file mode 100644 index 00000000..d3333570 --- /dev/null +++ b/staticfile_test.go @@ -0,0 +1,71 @@ +package beego + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "io" + "io/ioutil" + "os" + "testing" +) + +const licenseFile = "./LICENSE" + +func testOpenFile(encoding string, content []byte, t *testing.T) { + fi, _ := os.Stat(licenseFile) + b, n, sch, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Log(err) + t.Fail() + } + + t.Log("open static file encoding "+n, b) + + assetOpenFileAndContent(sch, content, t) +} +func TestOpenStaticFile_1(t *testing.T) { + file, _ := os.Open(licenseFile) + content, _ := ioutil.ReadAll(file) + testOpenFile("", content, t) +} + +func TestOpenStaticFileGzip_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("gzip", content, t) +} +func TestOpenStaticFileDeflate_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("deflate", content, t) +} + +func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) { + t.Log(sch.size, len(content)) + if sch.size != int64(len(content)) { + t.Log("static content file size not same") + t.Fail() + } + bs, _ := ioutil.ReadAll(sch) + for i, v := range content { + if v != bs[i] { + t.Log("content not same") + t.Fail() + } + } + if len(staticFileMap) == 0 { + t.Log("men map is empty") + t.Fail() + } +} diff --git a/swagger/docsSpec.go b/swagger/docs_spec.go similarity index 72% rename from swagger/docsSpec.go rename to swagger/docs_spec.go index 6f8fe1db..d8402aa5 100644 --- a/swagger/docsSpec.go +++ b/swagger/docs_spec.go @@ -12,53 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. -// swagger struct definition +// Package swagger struct definition package swagger +// SwaggerVersion show the current swagger version const SwaggerVersion = "1.2" +// ResourceListing list the resource type ResourceListing struct { - ApiVersion string `json:"apiVersion"` + APIVersion string `json:"apiVersion"` SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2 // BasePath string `json:"basePath"` obsolete in 1.1 - Apis []ApiRef `json:"apis"` - Infos Infomation `json:"info"` + APIs []APIRef `json:"apis"` + Info Information `json:"info"` } -type ApiRef struct { +// APIRef description the api path and description +type APIRef struct { Path string `json:"path"` // relative or absolute, must start with / Description string `json:"description"` } -type Infomation struct { +// Information show the API Information +type Information struct { Title string `json:"title,omitempty"` Description string `json:"description,omitempty"` Contact string `json:"contact,omitempty"` - TermsOfServiceUrl string `json:"termsOfServiceUrl,omitempty"` + TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"` License string `json:"license,omitempty"` - LicenseUrl string `json:"licenseUrl,omitempty"` + LicenseURL string `json:"licenseUrl,omitempty"` } -// https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json -type ApiDeclaration struct { - ApiVersion string `json:"apiVersion"` +// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json +type APIDeclaration struct { + APIVersion string `json:"apiVersion"` SwaggerVersion string `json:"swaggerVersion"` BasePath string `json:"basePath"` ResourcePath string `json:"resourcePath"` // must start with / Consumes []string `json:"consumes,omitempty"` Produces []string `json:"produces,omitempty"` - Apis []Api `json:"apis,omitempty"` + APIs []API `json:"apis,omitempty"` Models map[string]Model `json:"models,omitempty"` } -type Api struct { +// API show tha API struct +type API struct { Path string `json:"path"` // relative or absolute, must start with / Description string `json:"description"` Operations []Operation `json:"operations,omitempty"` } +// Operation desc the Operation type Operation struct { - HttpMethod string `json:"httpMethod"` + HTTPMethod string `json:"httpMethod"` Nickname string `json:"nickname"` Type string `json:"type"` // in 1.1 = DataType // ResponseClass string `json:"responseClass"` obsolete in 1.2 @@ -72,15 +78,18 @@ type Operation struct { Protocols []Protocol `json:"protocols,omitempty"` } +// Protocol support which Protocol type Protocol struct { } +// ResponseMessage Show the type ResponseMessage struct { Code int `json:"code"` Message string `json:"message"` ResponseModel string `json:"responseModel"` } +// Parameter desc the request parameters type Parameter struct { ParamType string `json:"paramType"` // path,query,body,header,form Name string `json:"name"` @@ -94,17 +103,20 @@ type Parameter struct { Maximum int `json:"maximum"` } +// ErrorResponse desc response type ErrorResponse struct { Code int `json:"code"` Reason string `json:"reason"` } +// Model define the data model type Model struct { - Id string `json:"id"` + ID string `json:"id"` Required []string `json:"required,omitempty"` Properties map[string]ModelProperty `json:"properties"` } +// ModelProperty define the properties type ModelProperty struct { Type string `json:"type"` Description string `json:"description"` @@ -112,20 +124,20 @@ type ModelProperty struct { Format string `json:"format"` } -// https://github.com/wordnik/swagger-core/wiki/authorizations +// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations type Authorization struct { LocalOAuth OAuth `json:"local-oauth"` - ApiKey ApiKey `json:"apiKey"` + APIKey APIKey `json:"apiKey"` } -// https://github.com/wordnik/swagger-core/wiki/authorizations +// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations type OAuth struct { Type string `json:"type"` // e.g. oauth2 Scopes []string `json:"scopes"` // e.g. PUBLIC GrantTypes map[string]GrantType `json:"grantTypes"` } -// https://github.com/wordnik/swagger-core/wiki/authorizations +// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations type GrantType struct { LoginEndpoint Endpoint `json:"loginEndpoint"` TokenName string `json:"tokenName"` // e.g. access_code @@ -133,16 +145,16 @@ type GrantType struct { TokenEndpoint Endpoint `json:"tokenEndpoint"` } -// https://github.com/wordnik/swagger-core/wiki/authorizations +// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations type Endpoint struct { - Url string `json:"url"` - ClientIdName string `json:"clientIdName"` + URL string `json:"url"` + ClientIDName string `json:"clientIdName"` ClientSecretName string `json:"clientSecretName"` TokenName string `json:"tokenName"` } -// https://github.com/wordnik/swagger-core/wiki/authorizations -type ApiKey struct { +// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations +type APIKey struct { Type string `json:"type"` // e.g. apiKey PassAs string `json:"passAs"` // e.g. header } diff --git a/template.go b/template.go index 64b1939e..9aac3ea2 100644 --- a/template.go +++ b/template.go @@ -28,17 +28,14 @@ import ( ) var ( - beegoTplFuncMap template.FuncMap - // beego template caching map and supported template file extensions. - BeeTemplates map[string]*template.Template - BeeTemplateExt []string + beegoTplFuncMap = make(template.FuncMap) + // BeeTemplates caching map and supported template file extensions. + BeeTemplates = make(map[string]*template.Template) + // BeeTemplateExt stores the template extention which will build + BeeTemplateExt = []string{"tpl", "html"} ) func init() { - BeeTemplates = make(map[string]*template.Template) - beegoTplFuncMap = make(template.FuncMap) - BeeTemplateExt = make([]string, 0) - BeeTemplateExt = append(BeeTemplateExt, "tpl", "html") beegoTplFuncMap["dateformat"] = DateFormat beegoTplFuncMap["date"] = Date beegoTplFuncMap["compare"] = Compare @@ -46,14 +43,15 @@ func init() { beegoTplFuncMap["not_nil"] = NotNil beegoTplFuncMap["not_null"] = NotNil beegoTplFuncMap["substr"] = Substr - beegoTplFuncMap["html2str"] = Html2str + beegoTplFuncMap["html2str"] = HTML2str beegoTplFuncMap["str2html"] = Str2html beegoTplFuncMap["htmlquote"] = Htmlquote beegoTplFuncMap["htmlunquote"] = Htmlunquote beegoTplFuncMap["renderform"] = RenderForm beegoTplFuncMap["assets_js"] = AssetsJs - beegoTplFuncMap["assets_css"] = AssetsCss + beegoTplFuncMap["assets_css"] = AssetsCSS beegoTplFuncMap["config"] = Config + beegoTplFuncMap["map_get"] = MapGet // go1.2 added template funcs // Comparisons @@ -64,7 +62,7 @@ func init() { beegoTplFuncMap["lt"] = lt // < beegoTplFuncMap["ne"] = ne // != - beegoTplFuncMap["urlfor"] = UrlFor // != + beegoTplFuncMap["urlfor"] = URLFor // != } // AddFuncMap let user to register a func in the template. @@ -78,7 +76,7 @@ type templatefile struct { files map[string][]string } -func (self *templatefile) visit(paths string, f os.FileInfo, err error) error { +func (tf *templatefile) visit(paths string, f os.FileInfo, err error) error { if f == nil { return err } @@ -91,21 +89,21 @@ func (self *templatefile) visit(paths string, f os.FileInfo, err error) error { replace := strings.NewReplacer("\\", "/") a := []byte(paths) - a = a[len([]byte(self.root)):] + a = a[len([]byte(tf.root)):] file := strings.TrimLeft(replace.Replace(string(a)), "/") subdir := filepath.Dir(file) - if _, ok := self.files[subdir]; ok { - self.files[subdir] = append(self.files[subdir], file) + if _, ok := tf.files[subdir]; ok { + tf.files[subdir] = append(tf.files[subdir], file) } else { m := make([]string, 1) m[0] = file - self.files[subdir] = m + tf.files[subdir] = m } return nil } -// return this path contains supported template extension of beego or not. +// HasTemplateExt return this path contains supported template extension of beego or not. func HasTemplateExt(paths string) bool { for _, v := range BeeTemplateExt { if strings.HasSuffix(paths, "."+v) { @@ -115,7 +113,7 @@ func HasTemplateExt(paths string) bool { return false } -// add new extension for template. +// AddTemplateExt add new extension for template. func AddTemplateExt(ext string) { for _, v := range BeeTemplateExt { if v == ext { @@ -125,15 +123,14 @@ func AddTemplateExt(ext string) { BeeTemplateExt = append(BeeTemplateExt, ext) } -// build all template files in a directory. +// BuildTemplate will build all template files in a directory. // it makes beego can render any template file in view directory. -func BuildTemplate(dir string) error { +func BuildTemplate(dir string, files ...string) error { if _, err := os.Stat(dir); err != nil { if os.IsNotExist(err) { return nil - } else { - return errors.New("dir open err") } + return errors.New("dir open err") } self := &templatefile{ root: dir, @@ -148,11 +145,13 @@ func BuildTemplate(dir string) error { } for _, v := range self.files { for _, file := range v { - t, err := getTemplate(self.root, file, v...) - if err != nil { - Trace("parse template err:", file, err) - } else { - BeeTemplates[file] = t + if len(files) == 0 || utils.InSlice(file, files) { + t, err := getTemplate(self.root, file, v...) + if err != nil { + Trace("parse template err:", file, err) + } else { + BeeTemplates[file] = t + } } } } @@ -177,7 +176,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp if err != nil { return nil, [][]string{}, err } - reg := regexp.MustCompile(TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") allsub := reg.FindAllStringSubmatch(string(data), -1) for _, m := range allsub { if len(m) == 2 { @@ -198,7 +197,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp } func getTemplate(root, file string, others ...string) (t *template.Template, err error) { - t = template.New(file).Delims(TemplateLeft, TemplateRight).Funcs(beegoTplFuncMap) + t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) var submods [][]string t, submods, err = getTplDeep(root, file, "", t) if err != nil { @@ -240,7 +239,7 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others if err != nil { continue } - reg := regexp.MustCompile(TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") allsub := reg.FindAllStringSubmatch(string(data), -1) for _, sub := range allsub { if len(sub) == 2 && sub[1] == m[1] { @@ -260,3 +259,30 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others } return } + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + BConfig.WebConfig.ViewsPath = path + return BeeApp +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + url = strings.TrimRight(url, "/") + BConfig.WebConfig.StaticDir[url] = path + return BeeApp +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + url = strings.TrimRight(url, "/") + delete(BConfig.WebConfig.StaticDir, url) + return BeeApp +} diff --git a/template_test.go b/template_test.go index b35da5ce..2e222efc 100644 --- a/template_test.go +++ b/template_test.go @@ -20,11 +20,11 @@ import ( "testing" ) -var header string = `{{define "header"}} +var header = `{{define "header"}}

Hello, astaxie!

{{end}}` -var index string = ` +var index = ` beego welcome template @@ -37,7 +37,7 @@ var index string = ` ` -var block string = `{{define "block"}} +var block = `{{define "block"}}

Hello, blocks!

{{end}}` @@ -82,7 +82,7 @@ func TestTemplate(t *testing.T) { os.RemoveAll(dir) } -var menu string = `